|
|
|
|
|
""" |
|
@Author : Qingping Zheng |
|
@Contact : [email protected] |
|
@File : edges.py |
|
@Time : 10/01/21 00:00 PM |
|
@Desc : |
|
@License : Licensed under the Apache License, Version 2.0 (the "License"); |
|
@Copyright : Copyright 2022 The Authors. All Rights Reserved. |
|
""" |
|
from __future__ import absolute_import |
|
from __future__ import division |
|
from __future__ import print_function |
|
|
|
import torch |
|
import torch.nn.functional as F |
|
import torch.nn as nn |
|
|
|
from inplace_abn import InPlaceABNSync |
|
|
|
|
|
class Edges(nn.Module):
    """Edge prediction head operating on three feature maps.

    Each of the three inputs is reduced to ``mid_fea`` channels by its own
    1x1 convolution followed by ``abn``. A single 3x3 convolution
    (``conv4``, shared by all three branches) turns every reduced feature
    map into an ``out_fea``-channel edge map. The second and third branches
    are bilinearly resized to the spatial size of the first input, then the
    per-branch results are concatenated along the channel dimension and
    fused by two 1x1 convolutions.

    Args:
        abn: Callable invoked as ``abn(mid_fea)`` that returns the
            normalisation module appended after every 1x1 reduction.
        in_fea: Channel counts of ``x1``, ``x2`` and ``x3``; entries
            ``0``, ``1`` and ``2`` are read.
        mid_fea: Output channels of each 1x1 reduction.
        out_fea: Output channels of the shared 3x3 convolution and of the
            ``conv5`` fusion layer.
    """

    # ``in_fea`` defaults to a tuple rather than a list so the default object
    # cannot be mutated and shared between instances; it is only indexed.
    def __init__(self, abn=InPlaceABNSync, in_fea=(256, 512, 1024), mid_fea=256, out_fea=2):
        super(Edges, self).__init__()

        # Attribute names and creation order are kept exactly as before so
        # that existing checkpoints (state_dict keys) and seeded weight
        # initialisation are unaffected.
        self.conv1 = self._make_reduce_branch(abn, in_fea[0], mid_fea)
        self.conv2 = self._make_reduce_branch(abn, in_fea[1], mid_fea)
        self.conv3 = self._make_reduce_branch(abn, in_fea[2], mid_fea)

        # Shared by all three branches: reduced features -> per-branch edge map.
        self.conv4 = nn.Conv2d(mid_fea, out_fea, kernel_size=3, padding=1, dilation=1, bias=True)

        # Both fusion layers consume the concatenation of the three
        # per-branch edge maps (hence ``out_fea * 3`` input channels).
        # ``conv5_b`` always emits 2 channels; ``conv5`` emits ``out_fea``.
        self.conv5_b = nn.Conv2d(out_fea * 3, 2, kernel_size=1, padding=0, dilation=1, bias=True)
        self.conv5 = nn.Conv2d(out_fea * 3, out_fea, kernel_size=1, padding=0, dilation=1, bias=True)

    @staticmethod
    def _make_reduce_branch(abn, in_channels, mid_channels):
        """Build one bias-free 1x1 convolution followed by ``abn(mid_channels)``."""
        return nn.Sequential(
            nn.Conv2d(in_channels, mid_channels, kernel_size=1, padding=0, dilation=1, bias=False),
            abn(mid_channels),
        )

    @staticmethod
    def _upsample_to(tensor, size):
        """Bilinearly resize ``tensor`` to ``size`` with ``align_corners=True``."""
        return F.interpolate(tensor, size=size, mode='bilinear', align_corners=True)

    def forward(self, x1, x2, x3):
        """Compute edge predictions from three feature maps.

        Args:
            x1: 4-D tensor; its last two dimensions define the output
                spatial size ``(h, w)``.
            x2: Second feature map, resized to ``(h, w)`` after reduction.
            x3: Third feature map, resized to ``(h, w)`` after reduction.

        Returns:
            A tuple ``(binary_edge, semantic_edge, edge_fea)``:

            * ``binary_edge`` -- output of ``conv5_b`` (2 channels).
            * ``semantic_edge`` -- output of ``conv5`` (``out_fea`` channels).
            * ``edge_fea`` -- channel-wise concatenation of the three
              reduced feature maps (``3 * mid_fea`` channels, assuming
              ``abn`` preserves the channel count).
        """
        _, _, h, w = x1.size()
        target_size = (h, w)

        # Reduce every input, then predict a per-branch edge map with the
        # shared 3x3 convolution at that branch's native resolution.
        edge1_fea = self.conv1(x1)
        edge1 = self.conv4(edge1_fea)
        edge2_fea = self.conv2(x2)
        edge2 = self.conv4(edge2_fea)
        edge3_fea = self.conv3(x3)
        edge3 = self.conv4(edge3_fea)

        # Branch 1 is not resized here; branches 2 and 3 are brought to the
        # spatial size of ``x1`` so everything can be concatenated.
        edge2_fea = self._upsample_to(edge2_fea, target_size)
        edge3_fea = self._upsample_to(edge3_fea, target_size)
        edge2 = self._upsample_to(edge2, target_size)
        edge3 = self._upsample_to(edge3, target_size)

        edge = torch.cat([edge1, edge2, edge3], dim=1)
        edge_fea = torch.cat([edge1_fea, edge2_fea, edge3_fea], dim=1)

        semantic_edge = self.conv5(edge)
        binary_edge = self.conv5_b(edge)

        return binary_edge, semantic_edge, edge_fea
|
|
|
|