File size: 2,497 Bytes
332190f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 |
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
"""
@Author : Qingping Zheng
@Contact : [email protected]
@File : edges.py
@Time : 10/01/21 00:00 PM
@Desc : Edge module that turns three feature maps into binary and semantic edge predictions.
@License : Licensed under the Apache License, Version 2.0 (the "License");
@Copyright : Copyright 2022 The Authors. All Rights Reserved.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import torch
import torch.nn.functional as F
import torch.nn as nn
from inplace_abn import InPlaceABNSync
class Edges(nn.Module):
    """Edge head that fuses three feature maps into edge predictions.

    Each input is projected to ``mid_fea`` channels by its own 1x1 convolution
    followed by the layer built from ``abn``. One shared 3x3 convolution
    (``conv4``) then scores every projected map. Results derived from the
    second and third inputs are bilinearly resized to the spatial size of the
    first input before the per-input results are concatenated.

    Args:
        abn: Callable invoked as ``abn(mid_fea)`` to build the layer placed
            after each 1x1 projection (``InPlaceABNSync`` by default).
        in_fea: Channel counts of the three inputs, in the order
            ``(x1, x2, x3)``. Only indices 0, 1 and 2 are read.
        mid_fea: Number of channels of each projected feature map.
        out_fea: Number of channels of each per-input edge map and of the
            semantic edge output.
    """

    # NOTE: the default for ``in_fea`` is a tuple rather than a list so the
    # default object cannot be mutated and shared between instances; it is
    # only ever indexed, so lists passed by callers keep working.
    def __init__(self, abn=InPlaceABNSync, in_fea=(256, 512, 1024), mid_fea=256, out_fea=2):
        super(Edges, self).__init__()

        def _projection(channels):
            # Bias-free 1x1 convolution followed by the ``abn`` layer.
            return nn.Sequential(
                nn.Conv2d(channels, mid_fea, kernel_size=1, padding=0, dilation=1, bias=False),
                abn(mid_fea),
            )

        # Attribute names and creation order are kept stable so that existing
        # checkpoints (state_dict keys) and seeded initialisation still match.
        self.conv1 = _projection(in_fea[0])
        self.conv2 = _projection(in_fea[1])
        self.conv3 = _projection(in_fea[2])

        # Shared across all three inputs in ``forward``.
        self.conv4 = nn.Conv2d(mid_fea, out_fea, kernel_size=3, padding=1, dilation=1, bias=True)

        # Both fusion layers consume the three concatenated edge maps
        # (``out_fea * 3`` channels); ``conv5_b`` always emits 2 channels.
        self.conv5_b = nn.Conv2d(out_fea * 3, 2, kernel_size=1, padding=0, dilation=1, bias=True)
        self.conv5 = nn.Conv2d(out_fea * 3, out_fea, kernel_size=1, padding=0, dilation=1, bias=True)

    def forward(self, x1, x2, x3):
        """Compute edge predictions from three feature maps.

        Args:
            x1: 4-D tensor fed to ``conv1``; its last two dimensions define
                the output spatial size.
            x2: Tensor fed to ``conv2``.
            x3: Tensor fed to ``conv3``.

        Returns:
            A tuple ``(binary_edge, semantic_edge, edge_fea)`` where
            ``binary_edge`` is the output of ``conv5_b``, ``semantic_edge`` is
            the output of ``conv5`` and ``edge_fea`` is the channel-wise
            concatenation of the three projected feature maps.
        """
        _, _, height, width = x1.size()
        target_size = (height, width)

        def _resize(tensor):
            # Bring a lower-resolution map onto the grid of ``x1``.
            return F.interpolate(tensor, size=target_size, mode='bilinear', align_corners=True)

        # Per-input projection, then the shared edge scorer.
        fea1 = self.conv1(x1)
        score1 = self.conv4(fea1)
        fea2 = self.conv2(x2)
        score2 = self.conv4(fea2)
        fea3 = self.conv3(x3)
        score3 = self.conv4(fea3)

        # The first input is used as-is; only the other two are resized.
        fea2 = _resize(fea2)
        fea3 = _resize(fea3)
        score2 = _resize(score2)
        score3 = _resize(score3)

        edge = torch.cat([score1, score2, score3], dim=1)
        edge_fea = torch.cat([fea1, fea2, fea3], dim=1)

        semantic_edge = self.conv5(edge)
        binary_edge = self.conv5_b(edge)
        return binary_edge, semantic_edge, edge_fea
|