|
import math |
|
import numpy as np |
|
|
|
import torch |
|
import torch.nn as nn |
|
|
|
from .common import * |
|
|
|
|
|
class Encoder(nn.Module):
    """Shuffle two inputs with a fractional scale and fuse them into one feature map.

    Args:
        in_channels: Channel count of each input given to ``forward``.
        depth: Exponent of the shuffle factor. The shuffler is constructed
            with a scale of ``1 / 2**depth`` and the fusion module with
            ``in_channels * 4**depth`` channels.
    """

    def __init__(self, in_channels=3, depth=3):
        super().__init__()

        # Scale below 1: presumably a space-to-depth rearrangement that trades
        # resolution for channels -- confirm against PixelShuffle in .common.
        self.shuffler = PixelShuffle(1 / 2**depth)

        activation = nn.LeakyReLU(negative_slope=0.2, inplace=True)

        # The channel count grows by 4 per depth level (2x along each spatial
        # axis). NOTE(review): the meaning of the leading 5 and 12 is defined
        # by Interpolation in .common and is not visible from this file.
        self.interpolate = Interpolation(
            5, 12, in_channels * (4**depth), act=activation
        )

    def forward(self, x1, x2):
        """
        Encoder: Shuffle-spread --> Feature Fusion --> Return fused features
        """
        # Both inputs pass through the same shuffler before being fused.
        feats1 = self.shuffler(x1)
        feats2 = self.shuffler(x2)

        return self.interpolate(feats1, feats2)
|
|
|
|
|
class Decoder(nn.Module):
    """Apply a ``PixelShuffle`` with scale ``2**depth`` to a feature map.

    Args:
        depth: Exponent of the shuffle factor handed to ``PixelShuffle``.
    """

    def __init__(self, depth=3):
        super().__init__()

        # Scale of 2**depth; presumably the inverse of the encoder's
        # 1 / 2**depth shuffle -- confirm against PixelShuffle in .common.
        self.shuffler = PixelShuffle(2**depth)

    def forward(self, feats):
        """Return ``feats`` passed through the shuffler."""
        return self.shuffler(feats)
|
|
|
|
|
class CAIN(nn.Module):
    """Encoder/decoder pair that produces one output from two inputs.

    Args:
        depth: Shuffle depth forwarded to both the encoder and the decoder.
    """

    def __init__(self, depth=3):
        super().__init__()

        self.encoder = Encoder(in_channels=3, depth=depth)
        self.decoder = Decoder(depth=depth)

    def forward(self, x1, x2):
        """Run both inputs through the encoder and decoder.

        Args:
            x1: First input.
            x2: Second input.

        Returns:
            A tuple ``(out, feats)``: the decoder output with the averaged
            ``sub_mean`` offsets added back, and the encoder's fused features.
        """
        # sub_mean returns the adjusted input plus a second value that is
        # added back at the end (presumably the removed mean -- see .common).
        x1, m1 = sub_mean(x1)
        x2, m2 = sub_mean(x2)

        # Padding is applied only outside training mode. The output-side
        # callable is kept in a variable that is always bound, so the later
        # use cannot hit an undefined name.
        pad_out = None
        if not self.training:
            pad_in, pad_out = InOutPaddings(x1)
            x1 = pad_in(x1)
            x2 = pad_in(x2)

        feats = self.encoder(x1, x2)
        out = self.decoder(feats)

        if pad_out is not None:
            out = pad_out(out)

        # Add back the average of the two offsets. Kept in place, as in the
        # original, so the result keeps the dtype of `out`.
        mi = (m1 + m2) / 2
        out += mi

        return out, feats