File size: 1,712 Bytes
332190f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 |
#!/usr/bin/env python
# -*- encoding: utf-8 -*-
"""
@Author : Qingping Zheng
@Contact : [email protected]
@File : parsing.py
@Time : 10/01/21 00:00 PM
@Desc :
@License : Licensed under the Apache License, Version 2.0 (the "License");
@Copyright : Copyright 2022 The Authors. All Rights Reserved.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import torch
import torch.nn.functional as F
import torch.nn as nn
from inplace_abn import InPlaceABNSync
class Parsing(nn.Module):
    """Decoder head that fuses two feature maps and predicts per-pixel classes.

    Layout (every convolution is 1x1; all are bias-free except the classifier):

    * ``conv1``: projects ``xt`` from ``in_plane1`` to 256 channels, then ``abn``.
    * ``conv2``: projects ``xl`` from ``in_plane2`` to 48 channels, then ``abn``.
    * ``conv3``: fuses the concatenated 256 + 48 = 304 channels down to 256
      through two (conv -> ``abn``) stages.
    * ``conv4``: classifier mapping 256 channels to ``num_classes`` (with bias).

    Args:
        in_plane1: Channel count of the ``xt`` input.
        in_plane2: Channel count of the ``xl`` input.
        num_classes: Number of output channels produced by the classifier.
        abn: Factory invoked as ``abn(num_features)`` to build each
            normalization layer; defaults to ``InPlaceABNSync``.
    """

    def __init__(self, in_plane1, in_plane2, num_classes, abn=InPlaceABNSync):
        super(Parsing, self).__init__()

        # Output widths of the two projection branches; conv3 consumes their
        # channel-wise concatenation, so its input width is their sum (304).
        top_width, low_width = 256, 48
        fused_width = 256

        def conv1x1(in_ch, out_ch):
            # Bias-free 1x1 conv; stride 1, padding 0 and dilation 1 are the
            # nn.Conv2d defaults.
            return nn.Conv2d(in_ch, out_ch, kernel_size=1, bias=False)

        # Attribute names and Sequential layouts define the state_dict keys
        # (e.g. "conv3.2.weight"); keep them stable so existing checkpoints load.
        self.conv1 = nn.Sequential(conv1x1(in_plane1, top_width), abn(top_width))
        self.conv2 = nn.Sequential(conv1x1(in_plane2, low_width), abn(low_width))
        self.conv3 = nn.Sequential(
            conv1x1(top_width + low_width, fused_width),
            abn(fused_width),
            conv1x1(fused_width, fused_width),
            abn(fused_width),
        )
        self.conv4 = nn.Conv2d(fused_width, num_classes, kernel_size=1, bias=True)

    def forward(self, xt, xl):
        """Fuse ``xt`` and ``xl`` and classify at ``xl``'s spatial resolution.

        Args:
            xt: Feature map with ``in_plane1`` channels. Presumably the deeper,
                lower-resolution input; TODO confirm with callers.
            xl: 4-D feature map with ``in_plane2`` channels. Its spatial size
                (last two dims) sets the output resolution.

        Returns:
            Tuple ``(seg, x)``. ``seg`` is the classifier output with
            ``num_classes`` channels. ``x`` is the 256-channel fused feature map
            that ``seg`` was computed from.
        """
        # Unpacking exactly four values also rejects an ``xl`` that is not 4-D.
        _, _, out_h, out_w = xl.size()
        top = F.interpolate(
            self.conv1(xt), size=(out_h, out_w), mode='bilinear', align_corners=True
        )
        fused = self.conv3(torch.cat([top, self.conv2(xl)], dim=1))
        return self.conv4(fused), fused
|