|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import math
import pathlib
import re
import sys

import cv2
import numpy
import numpy as np
import PIL
import PIL.Image
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.optim.lr_scheduler as lrs
from torch.utils.model_zoo import load_url as load_state_dict_from_url

from vfi_models.ops import FunctionCorrelation, FunctionAdaCoF, ModuleSoftsplat
from vfi_utils import get_ckpt_container_path
|
MODEL_TYPE = pathlib.Path(__file__).parent.name |
|
|
|
|
|
def identity(x):
    # plain identity function; note that the no-op norm layer used when
    # bn=False must be the Identity nn.Module defined further below
    return x
|
|
|
|
|
# cache the sampling grid and the ones-channel per flow shape at module level,
# so the memoization is actually effective across calls (inside the function
# the dicts were rebuilt, and therefore always empty, on every invocation)
backwarp_tenGrid = {}
backwarp_tenPartial = {}


def backwarp(tenInput, tenFlow):
    if str(tenFlow.shape) not in backwarp_tenGrid:
|
tenHor = ( |
|
torch.linspace( |
|
-1.0 + (1.0 / tenFlow.shape[3]), |
|
1.0 - (1.0 / tenFlow.shape[3]), |
|
tenFlow.shape[3], |
|
) |
|
.view(1, 1, 1, -1) |
|
.expand(-1, -1, tenFlow.shape[2], -1) |
|
) |
|
tenVer = ( |
|
torch.linspace( |
|
-1.0 + (1.0 / tenFlow.shape[2]), |
|
1.0 - (1.0 / tenFlow.shape[2]), |
|
tenFlow.shape[2], |
|
) |
|
.view(1, 1, -1, 1) |
|
.expand(-1, -1, -1, tenFlow.shape[3]) |
|
) |
|
|
|
backwarp_tenGrid[str(tenFlow.shape)] = torch.cat([tenHor, tenVer], 1).cuda() |
|
|
|
|
|
if str(tenFlow.shape) not in backwarp_tenPartial: |
|
backwarp_tenPartial[str(tenFlow.shape)] = tenFlow.new_ones( |
|
[tenFlow.shape[0], 1, tenFlow.shape[2], tenFlow.shape[3]] |
|
) |
|
|
|
|
|
tenFlow = torch.cat( |
|
[ |
|
tenFlow[:, 0:1, :, :] / ((tenInput.shape[3] - 1.0) / 2.0), |
|
tenFlow[:, 1:2, :, :] / ((tenInput.shape[2] - 1.0) / 2.0), |
|
], |
|
1, |
|
) |
|
tenInput = torch.cat([tenInput, backwarp_tenPartial[str(tenFlow.shape)]], 1) |
|
|
|
tenOutput = torch.nn.functional.grid_sample( |
|
input=tenInput, |
|
grid=(backwarp_tenGrid[str(tenFlow.shape)] + tenFlow).permute(0, 2, 3, 1), |
|
mode="bilinear", |
|
padding_mode="zeros", |
|
align_corners=False, |
|
) |
|
|
|
tenMask = tenOutput[:, -1:, :, :] |
|
tenMask[tenMask > 0.999] = 1.0 |
|
tenMask[tenMask < 1.0] = 0.0 |
|
|
|
return tenOutput[:, :-1, :, :] * tenMask |
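
# Usage sketch (an illustration): warp the second frame toward the first with
# a flow field given in pixels; out-of-frame samples are zeroed through the
# appended ones-channel mask. Requires CUDA, since the grid is built on GPU.
#   warped = backwarp(tenInput=img1, tenFlow=flow)  # img1: (B,3,H,W), flow: (B,2,H,W)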
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class PWCNet(torch.nn.Module): |
|
def __init__(self): |
|
super(PWCNet, self).__init__() |
|
|
|
class Extractor(torch.nn.Module): |
|
def __init__(self): |
|
super(Extractor, self).__init__() |
|
|
|
self.netOne = torch.nn.Sequential( |
|
torch.nn.Conv2d( |
|
in_channels=3, |
|
out_channels=16, |
|
kernel_size=3, |
|
stride=2, |
|
padding=1, |
|
), |
|
torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), |
|
torch.nn.Conv2d( |
|
in_channels=16, |
|
out_channels=16, |
|
kernel_size=3, |
|
stride=1, |
|
padding=1, |
|
), |
|
torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), |
|
torch.nn.Conv2d( |
|
in_channels=16, |
|
out_channels=16, |
|
kernel_size=3, |
|
stride=1, |
|
padding=1, |
|
), |
|
torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), |
|
) |
|
|
|
self.netTwo = torch.nn.Sequential( |
|
torch.nn.Conv2d( |
|
in_channels=16, |
|
out_channels=32, |
|
kernel_size=3, |
|
stride=2, |
|
padding=1, |
|
), |
|
torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), |
|
torch.nn.Conv2d( |
|
in_channels=32, |
|
out_channels=32, |
|
kernel_size=3, |
|
stride=1, |
|
padding=1, |
|
), |
|
torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), |
|
torch.nn.Conv2d( |
|
in_channels=32, |
|
out_channels=32, |
|
kernel_size=3, |
|
stride=1, |
|
padding=1, |
|
), |
|
torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), |
|
) |
|
|
|
self.netThr = torch.nn.Sequential( |
|
torch.nn.Conv2d( |
|
in_channels=32, |
|
out_channels=64, |
|
kernel_size=3, |
|
stride=2, |
|
padding=1, |
|
), |
|
torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), |
|
torch.nn.Conv2d( |
|
in_channels=64, |
|
out_channels=64, |
|
kernel_size=3, |
|
stride=1, |
|
padding=1, |
|
), |
|
torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), |
|
torch.nn.Conv2d( |
|
in_channels=64, |
|
out_channels=64, |
|
kernel_size=3, |
|
stride=1, |
|
padding=1, |
|
), |
|
torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), |
|
) |
|
|
|
self.netFou = torch.nn.Sequential( |
|
torch.nn.Conv2d( |
|
in_channels=64, |
|
out_channels=96, |
|
kernel_size=3, |
|
stride=2, |
|
padding=1, |
|
), |
|
torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), |
|
torch.nn.Conv2d( |
|
in_channels=96, |
|
out_channels=96, |
|
kernel_size=3, |
|
stride=1, |
|
padding=1, |
|
), |
|
torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), |
|
torch.nn.Conv2d( |
|
in_channels=96, |
|
out_channels=96, |
|
kernel_size=3, |
|
stride=1, |
|
padding=1, |
|
), |
|
torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), |
|
) |
|
|
|
self.netFiv = torch.nn.Sequential( |
|
torch.nn.Conv2d( |
|
in_channels=96, |
|
out_channels=128, |
|
kernel_size=3, |
|
stride=2, |
|
padding=1, |
|
), |
|
torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), |
|
torch.nn.Conv2d( |
|
in_channels=128, |
|
out_channels=128, |
|
kernel_size=3, |
|
stride=1, |
|
padding=1, |
|
), |
|
torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), |
|
torch.nn.Conv2d( |
|
in_channels=128, |
|
out_channels=128, |
|
kernel_size=3, |
|
stride=1, |
|
padding=1, |
|
), |
|
torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), |
|
) |
|
|
|
self.netSix = torch.nn.Sequential( |
|
torch.nn.Conv2d( |
|
in_channels=128, |
|
out_channels=196, |
|
kernel_size=3, |
|
stride=2, |
|
padding=1, |
|
), |
|
torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), |
|
torch.nn.Conv2d( |
|
in_channels=196, |
|
out_channels=196, |
|
kernel_size=3, |
|
stride=1, |
|
padding=1, |
|
), |
|
torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), |
|
torch.nn.Conv2d( |
|
in_channels=196, |
|
out_channels=196, |
|
kernel_size=3, |
|
stride=1, |
|
padding=1, |
|
), |
|
torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), |
|
) |
|
|
|
|
|
|
|
def forward(self, tenInput): |
|
tenOne = self.netOne(tenInput) |
|
tenTwo = self.netTwo(tenOne) |
|
tenThr = self.netThr(tenTwo) |
|
tenFou = self.netFou(tenThr) |
|
tenFiv = self.netFiv(tenFou) |
|
tenSix = self.netSix(tenFiv) |
|
|
|
return [tenOne, tenTwo, tenThr, tenFou, tenFiv, tenSix] |
|
|
|
|
|
|
|
|
|
|
|
class Decoder(torch.nn.Module): |
|
def __init__(self, intLevel): |
|
super(Decoder, self).__init__() |
|
|
|
intPrevious = [ |
|
None, |
|
None, |
|
81 + 32 + 2 + 2, |
|
81 + 64 + 2 + 2, |
|
81 + 96 + 2 + 2, |
|
81 + 128 + 2 + 2, |
|
81, |
|
None, |
|
][intLevel + 1] |
|
intCurrent = [ |
|
None, |
|
None, |
|
81 + 32 + 2 + 2, |
|
81 + 64 + 2 + 2, |
|
81 + 96 + 2 + 2, |
|
81 + 128 + 2 + 2, |
|
81, |
|
None, |
|
][intLevel + 0] |
|
|
|
if intLevel < 6: |
|
self.netUpflow = torch.nn.ConvTranspose2d( |
|
in_channels=2, |
|
out_channels=2, |
|
kernel_size=4, |
|
stride=2, |
|
padding=1, |
|
) |
|
if intLevel < 6: |
|
self.netUpfeat = torch.nn.ConvTranspose2d( |
|
in_channels=intPrevious + 128 + 128 + 96 + 64 + 32, |
|
out_channels=2, |
|
kernel_size=4, |
|
stride=2, |
|
padding=1, |
|
) |
|
if intLevel < 6: |
|
self.fltBackwarp = [None, None, None, 5.0, 2.5, 1.25, 0.625, None][ |
|
intLevel + 1 |
|
] |
|
|
|
self.netOne = torch.nn.Sequential( |
|
torch.nn.Conv2d( |
|
in_channels=intCurrent, |
|
out_channels=128, |
|
kernel_size=3, |
|
stride=1, |
|
padding=1, |
|
), |
|
torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), |
|
) |
|
|
|
self.netTwo = torch.nn.Sequential( |
|
torch.nn.Conv2d( |
|
in_channels=intCurrent + 128, |
|
out_channels=128, |
|
kernel_size=3, |
|
stride=1, |
|
padding=1, |
|
), |
|
torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), |
|
) |
|
|
|
self.netThr = torch.nn.Sequential( |
|
torch.nn.Conv2d( |
|
in_channels=intCurrent + 128 + 128, |
|
out_channels=96, |
|
kernel_size=3, |
|
stride=1, |
|
padding=1, |
|
), |
|
torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), |
|
) |
|
|
|
self.netFou = torch.nn.Sequential( |
|
torch.nn.Conv2d( |
|
in_channels=intCurrent + 128 + 128 + 96, |
|
out_channels=64, |
|
kernel_size=3, |
|
stride=1, |
|
padding=1, |
|
), |
|
torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), |
|
) |
|
|
|
self.netFiv = torch.nn.Sequential( |
|
torch.nn.Conv2d( |
|
in_channels=intCurrent + 128 + 128 + 96 + 64, |
|
out_channels=32, |
|
kernel_size=3, |
|
stride=1, |
|
padding=1, |
|
), |
|
torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), |
|
) |
|
|
|
self.netSix = torch.nn.Sequential( |
|
torch.nn.Conv2d( |
|
in_channels=intCurrent + 128 + 128 + 96 + 64 + 32, |
|
out_channels=2, |
|
kernel_size=3, |
|
stride=1, |
|
padding=1, |
|
) |
|
) |
|
|
|
|
|
|
|
def forward(self, tenFirst, tenSecond, objPrevious): |
|
                if objPrevious is None:
|
|
|
tenVolume = torch.nn.functional.leaky_relu( |
|
input=FunctionCorrelation( |
|
tenFirst=tenFirst, tenSecond=tenSecond |
|
), |
|
negative_slope=0.1, |
|
inplace=False, |
|
) |
|
|
|
tenFeat = torch.cat([tenVolume], 1) |
|
|
|
                else:
|
tenFlow = self.netUpflow(objPrevious["tenFlow"]) |
|
tenFeat = self.netUpfeat(objPrevious["tenFeat"]) |
|
|
|
tenVolume = torch.nn.functional.leaky_relu( |
|
input=FunctionCorrelation( |
|
tenFirst=tenFirst, |
|
tenSecond=backwarp( |
|
tenInput=tenSecond, tenFlow=tenFlow * self.fltBackwarp |
|
), |
|
), |
|
negative_slope=0.1, |
|
inplace=False, |
|
) |
|
|
|
tenFeat = torch.cat([tenVolume, tenFirst, tenFlow, tenFeat], 1) |
|
|
|
|
|
|
|
tenFeat = torch.cat([self.netOne(tenFeat), tenFeat], 1) |
|
tenFeat = torch.cat([self.netTwo(tenFeat), tenFeat], 1) |
|
tenFeat = torch.cat([self.netThr(tenFeat), tenFeat], 1) |
|
tenFeat = torch.cat([self.netFou(tenFeat), tenFeat], 1) |
|
tenFeat = torch.cat([self.netFiv(tenFeat), tenFeat], 1) |
|
|
|
tenFlow = self.netSix(tenFeat) |
|
|
|
return {"tenFlow": tenFlow, "tenFeat": tenFeat} |
|
|
|
|
|
|
|
|
|
|
|
class Refiner(torch.nn.Module): |
|
def __init__(self): |
|
super(Refiner, self).__init__() |
|
|
|
self.netMain = torch.nn.Sequential( |
|
torch.nn.Conv2d( |
|
in_channels=81 + 32 + 2 + 2 + 128 + 128 + 96 + 64 + 32, |
|
out_channels=128, |
|
kernel_size=3, |
|
stride=1, |
|
padding=1, |
|
dilation=1, |
|
), |
|
torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), |
|
torch.nn.Conv2d( |
|
in_channels=128, |
|
out_channels=128, |
|
kernel_size=3, |
|
stride=1, |
|
padding=2, |
|
dilation=2, |
|
), |
|
torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), |
|
torch.nn.Conv2d( |
|
in_channels=128, |
|
out_channels=128, |
|
kernel_size=3, |
|
stride=1, |
|
padding=4, |
|
dilation=4, |
|
), |
|
torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), |
|
torch.nn.Conv2d( |
|
in_channels=128, |
|
out_channels=96, |
|
kernel_size=3, |
|
stride=1, |
|
padding=8, |
|
dilation=8, |
|
), |
|
torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), |
|
torch.nn.Conv2d( |
|
in_channels=96, |
|
out_channels=64, |
|
kernel_size=3, |
|
stride=1, |
|
padding=16, |
|
dilation=16, |
|
), |
|
torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), |
|
torch.nn.Conv2d( |
|
in_channels=64, |
|
out_channels=32, |
|
kernel_size=3, |
|
stride=1, |
|
padding=1, |
|
dilation=1, |
|
), |
|
torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), |
|
torch.nn.Conv2d( |
|
in_channels=32, |
|
out_channels=2, |
|
kernel_size=3, |
|
stride=1, |
|
padding=1, |
|
dilation=1, |
|
), |
|
) |
|
|
|
|
|
|
|
def forward(self, tenInput): |
|
return self.netMain(tenInput) |
|
|
|
|
|
|
|
|
|
|
|
self.netExtractor = Extractor() |
|
|
|
self.netTwo = Decoder(2) |
|
self.netThr = Decoder(3) |
|
self.netFou = Decoder(4) |
|
self.netFiv = Decoder(5) |
|
self.netSix = Decoder(6) |
|
|
|
self.netRefiner = Refiner() |
|
|
|
        self.load_state_dict(
            {
                strKey.replace("module", "net"): tenWeight
                for strKey, tenWeight in torch.hub.load_state_dict_from_url(
                    url="http://content.sniklaus.com/github/pytorch-pwc/network-default.pytorch",
                    model_dir=get_ckpt_container_path(MODEL_TYPE),
                ).items()
            }
        )
|
|
|
|
|
|
|
def forward(self, tenFirst, tenSecond, *args): |
|
|
|
if len(args) == 0: |
|
tenFirst = self.netExtractor(tenFirst) |
|
tenSecond = self.netExtractor(tenSecond) |
|
else: |
|
tenFirst, tenSecond = args |
|
|
|
objEstimate = self.netSix(tenFirst[-1], tenSecond[-1], None) |
|
objEstimate = self.netFiv(tenFirst[-2], tenSecond[-2], objEstimate) |
|
objEstimate = self.netFou(tenFirst[-3], tenSecond[-3], objEstimate) |
|
objEstimate = self.netThr(tenFirst[-4], tenSecond[-4], objEstimate) |
|
objEstimate = self.netTwo(tenFirst[-5], tenSecond[-5], objEstimate) |
|
|
|
return objEstimate["tenFlow"] + self.netRefiner(objEstimate["tenFeat"]) |
|
|
|
|
|
|
|
def extract_pyramid(self, tenFirst, tenSecond): |
|
return self.netExtractor(tenFirst), self.netExtractor(tenSecond) |
|
|
|
def extract_pyramid_single(self, tenFirst): |
|
return self.netExtractor(tenFirst) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Upsampler_8tap(nn.Module): |
|
def __init__(self): |
|
super(Upsampler_8tap, self).__init__() |
|
filt_8tap = torch.tensor([[-1, 4, -11, 40, 40, -11, 4, -1]]).div(64) |
|
self.filter = nn.Parameter(filt_8tap.repeat(3, 1, 1, 1), requires_grad=False) |
|
|
|
def forward(self, im): |
|
b, c, h, w = im.shape |
|
        im_up = im.new_zeros(b, c, h * 2, w * 2)  # match input dtype and device
|
im_up[:, :, ::2, ::2] = im |
|
|
|
p = (8 - 1) // 2 |
|
im_up_row = F.conv2d( |
|
F.pad(im, pad=(p, p + 1, 0, 0), mode="reflect"), self.filter, groups=3 |
|
) |
|
im_up[:, :, 0::2, 1::2] = im_up_row |
|
im_up_col = torch.transpose( |
|
F.conv2d( |
|
F.pad(torch.transpose(im, 2, 3), pad=(p, p + 1, 0, 0), mode="reflect"), |
|
self.filter, |
|
groups=3, |
|
), |
|
2, |
|
3, |
|
) |
|
im_up[:, :, 1::2, 0::2] = im_up_col |
|
im_up_cross = F.conv2d( |
|
F.pad(im_up[:, :, 1::2, ::2], pad=(p, p + 1, 0, 0), mode="reflect"), |
|
self.filter, |
|
groups=3, |
|
) |
|
im_up[:, :, 1::2, 1::2] = im_up_cross |
|
return im_up |
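
# The taps [-1, 4, -11, 40, 40, -11, 4, -1] / 64 are the standard 8-tap
# half-sample luma interpolation filter (the same coefficients HEVC uses);
# forward() doubles the spatial resolution:
#   up = Upsampler_8tap()(torch.rand(1, 3, 32, 32))  # -> (1, 3, 64, 64)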
|
|
|
|
|
|
|
|
|
model_urls = { |
|
"r3d_18": "https://download.pytorch.org/models/r3d_18-b3b3357e.pth", |
|
"mc3_18": "https://download.pytorch.org/models/mc3_18-a90a0ba3.pth", |
|
"r2plus1d_18": "https://download.pytorch.org/models/r2plus1d_18-91a641e6.pth", |
|
} |
|
|
|
|
|
class Conv3DSimple(nn.Conv3d): |
|
def __init__(self, in_planes, out_planes, midplanes=None, stride=1, padding=1): |
|
super(Conv3DSimple, self).__init__( |
|
in_channels=in_planes, |
|
out_channels=out_planes, |
|
kernel_size=(3, 3, 3), |
|
stride=stride, |
|
padding=padding, |
|
bias=False, |
|
) |
|
|
|
@staticmethod |
|
def get_downsample_stride(stride, temporal_stride): |
|
if temporal_stride: |
|
return (temporal_stride, stride, stride) |
|
else: |
|
return (stride, stride, stride) |
|
|
|
|
|
class Conv2Plus1D(nn.Sequential): |
|
def __init__(self, in_planes, out_planes, midplanes, stride=1, padding=1): |
|
super(Conv2Plus1D, self).__init__( |
|
nn.Conv3d( |
|
in_planes, |
|
midplanes, |
|
kernel_size=(1, 3, 3), |
|
stride=(1, stride, stride), |
|
padding=(0, padding, padding), |
|
bias=False, |
|
), |
|
batchnorm(midplanes), |
|
nn.ReLU(inplace=True), |
|
nn.Conv3d( |
|
midplanes, |
|
out_planes, |
|
kernel_size=(3, 1, 1), |
|
stride=(stride, 1, 1), |
|
padding=(padding, 0, 0), |
|
bias=False, |
|
), |
|
) |
|
|
|
    @staticmethod
    def get_downsample_stride(stride, temporal_stride=None):
        # accept the temporal_stride argument that VideoResNet._make_layer
        # passes to every conv builder (the one-argument form crashed there)
        if temporal_stride:
            return (temporal_stride, stride, stride)
        return (stride, stride, stride)
|
|
|
|
|
class Conv3DNoTemporal(nn.Conv3d): |
|
def __init__(self, in_planes, out_planes, midplanes=None, stride=1, padding=1): |
|
super(Conv3DNoTemporal, self).__init__( |
|
in_channels=in_planes, |
|
out_channels=out_planes, |
|
kernel_size=(1, 3, 3), |
|
stride=(1, stride, stride), |
|
padding=(0, padding, padding), |
|
bias=False, |
|
) |
|
|
|
    @staticmethod
    def get_downsample_stride(stride, temporal_stride=None):
        # never downsample temporally, whatever temporal stride is requested;
        # the extra argument matches the call in VideoResNet._make_layer
        return (1, stride, stride)
|
|
|
|
|
class SEGating(nn.Module): |
|
def __init__(self, inplanes, reduction=16): |
|
super().__init__() |
|
|
|
self.pool = nn.AdaptiveAvgPool3d(1) |
|
self.attn_layer = nn.Sequential( |
|
nn.Conv3d(inplanes, inplanes, kernel_size=1, stride=1, bias=True), |
|
nn.Sigmoid(), |
|
) |
|
|
|
def forward(self, x): |
|
out = self.pool(x) |
|
y = self.attn_layer(out) |
|
return x * y |
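
# SEGating is a squeeze-and-excite style channel gate: global average pooling
# down to 1x1x1, a 1x1x1 convolution plus sigmoid, then channel-wise rescaling
# of the input. (The `reduction` argument is accepted for API compatibility
# but unused: the gating conv keeps the full channel count.)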
|
|
|
|
|
class BasicBlock(nn.Module): |
|
expansion = 1 |
|
|
|
def __init__(self, inplanes, planes, conv_builder, stride=1, downsample=None): |
|
midplanes = (inplanes * planes * 3 * 3 * 3) // (inplanes * 3 * 3 + 3 * planes) |
|
|
|
super(BasicBlock, self).__init__() |
|
self.conv1 = nn.Sequential( |
|
conv_builder(inplanes, planes, midplanes, stride), |
|
batchnorm(planes), |
|
nn.ReLU(inplace=True), |
|
) |
|
self.conv2 = nn.Sequential( |
|
conv_builder(planes, planes, midplanes), batchnorm(planes) |
|
) |
|
self.fg = SEGating(planes) |
|
self.relu = nn.ReLU(inplace=True) |
|
self.downsample = downsample |
|
self.stride = stride |
|
|
|
def forward(self, x): |
|
residual = x |
|
out = self.conv1(x) |
|
out = self.conv2(out) |
|
out = self.fg(out) |
|
if self.downsample is not None: |
|
residual = self.downsample(x) |
|
|
|
out += residual |
|
out = self.relu(out) |
|
|
|
return out |
|
|
|
|
|
class Bottleneck(nn.Module): |
|
expansion = 4 |
|
|
|
def __init__(self, inplanes, planes, conv_builder, stride=1, downsample=None): |
|
super(Bottleneck, self).__init__() |
|
midplanes = (inplanes * planes * 3 * 3 * 3) // (inplanes * 3 * 3 + 3 * planes) |
|
|
|
|
|
self.conv1 = nn.Sequential( |
|
nn.Conv3d(inplanes, planes, kernel_size=1, bias=False), |
|
batchnorm(planes), |
|
nn.ReLU(inplace=True), |
|
) |
|
|
|
self.conv2 = nn.Sequential( |
|
conv_builder(planes, planes, midplanes, stride), |
|
batchnorm(planes), |
|
nn.ReLU(inplace=True), |
|
) |
|
|
|
|
|
self.conv3 = nn.Sequential( |
|
nn.Conv3d(planes, planes * self.expansion, kernel_size=1, bias=False), |
|
batchnorm(planes * self.expansion), |
|
) |
|
self.relu = nn.ReLU(inplace=True) |
|
self.downsample = downsample |
|
self.stride = stride |
|
|
|
def forward(self, x): |
|
residual = x |
|
|
|
out = self.conv1(x) |
|
out = self.conv2(out) |
|
out = self.conv3(out) |
|
|
|
if self.downsample is not None: |
|
residual = self.downsample(x) |
|
|
|
out += residual |
|
out = self.relu(out) |
|
|
|
return out |
|
|
|
|
|
class BasicStem(nn.Sequential): |
|
"""The default conv-batchnorm-relu stem""" |
|
|
|
def __init__(self, outplanes=32): |
|
super(BasicStem, self).__init__( |
|
nn.Conv3d( |
|
3, |
|
outplanes, |
|
kernel_size=(3, 7, 7), |
|
stride=(1, 2, 2), |
|
padding=(1, 3, 3), |
|
bias=False, |
|
), |
|
batchnorm(outplanes), |
|
nn.ReLU(inplace=True), |
|
) |
|
|
|
|
|
class R2Plus1dStem(nn.Sequential): |
|
"""R(2+1)D stem is different than the default one as it uses separated 3D convolution""" |
|
|
|
def __init__(self): |
|
super(R2Plus1dStem, self).__init__( |
|
nn.Conv3d( |
|
3, |
|
45, |
|
kernel_size=(1, 7, 7), |
|
stride=(1, 2, 2), |
|
padding=(0, 3, 3), |
|
bias=False, |
|
), |
|
batchnorm(45), |
|
nn.ReLU(inplace=True), |
|
nn.Conv3d( |
|
45, |
|
64, |
|
kernel_size=(3, 1, 1), |
|
stride=(1, 1, 1), |
|
padding=(1, 0, 0), |
|
bias=False, |
|
), |
|
batchnorm(64), |
|
nn.ReLU(inplace=True), |
|
) |
|
|
|
|
|
class VideoResNet(nn.Module): |
|
def __init__( |
|
self, |
|
block, |
|
conv_makers, |
|
layers, |
|
stem, |
|
zero_init_residual=False, |
|
channels=[32, 64, 96, 128], |
|
): |
|
"""Generic resnet video generator. |
|
|
|
Args: |
|
block (nn.Module): resnet building block |
|
conv_makers (list(functions)): generator function for each layer |
|
layers (List[int]): number of blocks per layer |
|
stem (nn.Module, optional): Resnet stem, if None, defaults to conv-bn-relu. Defaults to None. |
|
zero_init_residual (bool, optional): Zero init bottleneck residual BN. Defaults to False. |
|
""" |
|
super(VideoResNet, self).__init__() |
|
self.inplanes = channels[0] |
|
|
|
self.stem = stem() |
|
|
|
self.layer1 = self._make_layer( |
|
block, conv_makers[0], channels[0], layers[0], stride=1 |
|
) |
|
self.layer2 = self._make_layer( |
|
block, conv_makers[1], channels[1], layers[1], stride=2, temporal_stride=1 |
|
) |
|
self.layer3 = self._make_layer( |
|
block, conv_makers[2], channels[2], layers[2], stride=2, temporal_stride=1 |
|
) |
|
self.layer4 = self._make_layer( |
|
block, conv_makers[3], channels[3], layers[3], stride=1, temporal_stride=1 |
|
) |
|
|
|
|
|
self._initialize_weights() |
|
|
|
if zero_init_residual: |
|
for m in self.modules(): |
|
if isinstance(m, Bottleneck): |
|
nn.init.constant_(m.bn3.weight, 0) |
|
|
|
def forward(self, x): |
|
tensorConv0 = self.stem(x) |
|
tensorConv1 = self.layer1(tensorConv0) |
|
tensorConv2 = self.layer2(tensorConv1) |
|
tensorConv3 = self.layer3(tensorConv2) |
|
tensorConv4 = self.layer4(tensorConv3) |
|
return tensorConv0, tensorConv1, tensorConv2, tensorConv3, tensorConv4 |
|
|
|
def _make_layer( |
|
self, block, conv_builder, planes, blocks, stride=1, temporal_stride=None |
|
): |
|
downsample = None |
|
|
|
if stride != 1 or self.inplanes != planes * block.expansion: |
|
ds_stride = conv_builder.get_downsample_stride(stride, temporal_stride) |
|
downsample = nn.Sequential( |
|
nn.Conv3d( |
|
self.inplanes, |
|
planes * block.expansion, |
|
kernel_size=1, |
|
stride=ds_stride, |
|
bias=False, |
|
), |
|
batchnorm(planes * block.expansion), |
|
) |
|
stride = ds_stride |
|
|
|
layers = [] |
|
layers.append(block(self.inplanes, planes, conv_builder, stride, downsample)) |
|
|
|
self.inplanes = planes * block.expansion |
|
for i in range(1, blocks): |
|
layers.append(block(self.inplanes, planes, conv_builder)) |
|
|
|
return nn.Sequential(*layers) |
|
|
|
def _initialize_weights(self): |
|
for m in self.modules(): |
|
if isinstance(m, nn.Conv3d): |
|
nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu") |
|
if m.bias is not None: |
|
nn.init.constant_(m.bias, 0) |
|
elif isinstance(m, nn.BatchNorm3d): |
|
nn.init.constant_(m.weight, 1) |
|
nn.init.constant_(m.bias, 0) |
|
elif isinstance(m, nn.Linear): |
|
nn.init.normal_(m.weight, 0, 0.01) |
|
nn.init.constant_(m.bias, 0) |
|
|
|
|
|
def _video_resnet(arch, pretrained=False, progress=True, **kwargs): |
|
model = VideoResNet(**kwargs) |
|
|
|
    if pretrained:
        # note: the torchvision checkpoints assume the stock torchvision
        # architecture; loading them into this modified variant (custom
        # channel widths, SEGating) only succeeds if the shapes line up
        state_dict = load_state_dict_from_url(
            model_urls[arch],
            progress=progress,
            model_dir=get_ckpt_container_path(MODEL_TYPE),
        )
        model.load_state_dict(state_dict)
    return model
|
|
|
|
|
def r3d_18(bn=False, pretrained=False, progress=True, **kwargs): |
|
"""Construct 18 layer Resnet3D model as in |
|
https://arxiv.org/abs/1711.11248 |
|
|
|
Args: |
|
pretrained (bool): If True, returns a model pre-trained on Kinetics-400 |
|
progress (bool): If True, displays a progress bar of the download to stderr |
|
|
|
Returns: |
|
nn.Module: R3D-18 network |
|
""" |
|
|
|
global batchnorm |
|
    if bn:
        batchnorm = nn.BatchNorm3d
    else:
        # use the Identity module (it accepts and ignores the channel
        # argument); the bare `identity` function would place an int inside
        # nn.Sequential and crash at construction time
        batchnorm = Identity
|
|
|
return _video_resnet( |
|
"r3d_18", |
|
pretrained, |
|
progress, |
|
block=BasicBlock, |
|
conv_makers=[Conv3DSimple] * 4, |
|
layers=[2, 2, 2, 2], |
|
stem=BasicStem, |
|
**kwargs, |
|
) |
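
# Example (a sketch, not part of the training code): building the gated 3D
# encoder the way UNet3d_18 further below does.
#   enc = r3d_18(bn=False, channels=[32, 64, 96, 128])
#   feats = enc(torch.rand(1, 3, 4, 64, 64))  # five maps: stem through layer4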
|
|
|
|
|
def mc3_18(bn=False, pretrained=False, progress=True, **kwargs): |
|
"""Constructor for 18 layer Mixed Convolution network as in |
|
https://arxiv.org/abs/1711.11248 |
|
|
|
Args: |
|
pretrained (bool): If True, returns a model pre-trained on Kinetics-400 |
|
progress (bool): If True, displays a progress bar of the download to stderr |
|
|
|
Returns: |
|
nn.Module: MC3 Network definition |
|
""" |
|
global batchnorm |
|
    if bn:
        batchnorm = nn.BatchNorm3d
    else:
        # see r3d_18: the Identity module, not the `identity` function
        batchnorm = Identity
|
|
|
return _video_resnet( |
|
"mc3_18", |
|
pretrained, |
|
progress, |
|
block=BasicBlock, |
|
conv_makers=[Conv3DSimple] + [Conv3DNoTemporal] * 3, |
|
layers=[2, 2, 2, 2], |
|
stem=BasicStem, |
|
**kwargs, |
|
) |
|
|
|
|
|
def r2plus1d_18(bn=False, pretrained=False, progress=True, **kwargs): |
|
"""Constructor for the 18 layer deep R(2+1)D network as in |
|
https://arxiv.org/abs/1711.11248 |
|
|
|
Args: |
|
pretrained (bool): If True, returns a model pre-trained on Kinetics-400 |
|
progress (bool): If True, displays a progress bar of the download to stderr |
|
|
|
Returns: |
|
nn.Module: R(2+1)D-18 network |
|
""" |
|
|
|
global batchnorm |
|
    if bn:
        batchnorm = nn.BatchNorm3d
    else:
        # see r3d_18: the Identity module, not the `identity` function
        batchnorm = Identity
|
|
|
return _video_resnet( |
|
"r2plus1d_18", |
|
pretrained, |
|
progress, |
|
block=BasicBlock, |
|
conv_makers=[Conv2Plus1D] * 4, |
|
layers=[2, 2, 2, 2], |
|
stem=R2Plus1dStem, |
|
**kwargs, |
|
) |
|
|
|
|
|
class upConv3D(nn.Module): |
|
def __init__(self, in_ch, out_ch, kernel_size, stride, padding, upmode="transpose"): |
|
super().__init__() |
|
self.upmode = upmode |
|
if self.upmode == "transpose": |
|
self.upconv = nn.ModuleList( |
|
[ |
|
nn.ConvTranspose3d( |
|
in_ch, |
|
out_ch, |
|
kernel_size=kernel_size, |
|
stride=stride, |
|
padding=padding, |
|
), |
|
SEGating(out_ch), |
|
batchnorm(out_ch), |
|
] |
|
) |
|
else: |
|
self.upconv = nn.ModuleList( |
|
[ |
|
nn.Upsample( |
|
mode="trilinear", scale_factor=(1, 2, 2), align_corners=False |
|
), |
|
nn.Conv3d(in_ch, out_ch, kernel_size=1, stride=1), |
|
SEGating(out_ch), |
|
batchnorm(out_ch), |
|
] |
|
) |
|
self.upconv = nn.Sequential(*self.upconv) |
|
|
|
def forward(self, x): |
|
return self.upconv(x) |
|
|
|
|
|
class Conv_3d(nn.Module): |
|
def __init__(self, in_ch, out_ch, kernel_size, stride=1, padding=0, bias=True): |
|
super().__init__() |
|
self.conv = nn.Sequential( |
|
nn.Conv3d( |
|
in_ch, |
|
out_ch, |
|
kernel_size=kernel_size, |
|
stride=stride, |
|
padding=padding, |
|
bias=bias, |
|
), |
|
SEGating(out_ch), |
|
batchnorm(out_ch), |
|
) |
|
|
|
def forward(self, x): |
|
return self.conv(x) |
|
|
|
|
|
def make_optimizer(args, my_model): |
|
trainable = filter(lambda x: x.requires_grad, my_model.parameters()) |
|
|
|
if args.optimizer == "SGD": |
|
optimizer_function = optim.SGD |
|
kwargs = {"momentum": 0.9} |
|
elif args.optimizer == "ADAM": |
|
optimizer_function = optim.Adam |
|
kwargs = {"betas": (0.9, 0.999), "eps": 1e-08} |
|
elif args.optimizer == "ADAMax": |
|
optimizer_function = optim.Adamax |
|
kwargs = {"betas": (0.9, 0.999), "eps": 1e-08} |
|
elif args.optimizer == "RMSprop": |
|
optimizer_function = optim.RMSprop |
|
kwargs = {"eps": 1e-08} |
|
|
|
kwargs["lr"] = args.lr |
|
kwargs["weight_decay"] = args.weight_decay |
|
|
|
return optimizer_function(trainable, **kwargs) |
|
|
|
|
|
def make_scheduler(args, my_optimizer): |
|
if args.decay_type == "step": |
|
scheduler = lrs.StepLR(my_optimizer, step_size=args.lr_decay, gamma=args.gamma) |
|
elif args.decay_type.find("step") >= 0: |
|
milestones = args.decay_type.split("_") |
|
milestones.pop(0) |
|
milestones = list(map(lambda x: int(x), milestones)) |
|
scheduler = lrs.MultiStepLR( |
|
my_optimizer, milestones=milestones, gamma=args.gamma |
|
) |
|
elif args.decay_type == "plateau": |
|
scheduler = lrs.ReduceLROnPlateau( |
|
my_optimizer, |
|
mode="max", |
|
factor=args.gamma, |
|
patience=args.patience, |
|
threshold=0.01, |
|
threshold_mode="abs", |
|
            verbose=True,
        )
    else:
        raise ValueError(f"unsupported decay_type: {args.decay_type}")
|
|
|
return scheduler |
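
# The `args` object only needs the fields read above; illustrated here with a
# stand-in namespace (the real values come from the training script's parser,
# and `model` is your network):
#   from types import SimpleNamespace
#   args = SimpleNamespace(optimizer="ADAM", lr=1e-4, weight_decay=0,
#                          decay_type="step", lr_decay=20, gamma=0.5)
#   optimizer = make_optimizer(args, model)
#   scheduler = make_scheduler(args, optimizer)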
|
|
|
|
|
def gaussian_kernel(sz, sigma): |
|
k = torch.arange(-(sz - 1) / 2, (sz + 1) / 2) |
|
k = torch.exp(-1.0 / (2 * sigma**2) * k**2) |
|
k = k.reshape(-1, 1) * k.reshape(1, -1) |
|
k = k / torch.sum(k) |
|
return k |
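
# e.g. gaussian_kernel(5, 1.0) is a 5x5 kernel summing to 1; usable as a blur
# weight via F.conv2d(x, k.expand(3, 1, -1, -1), padding=2, groups=3)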
|
|
|
|
|
def moduleNormalize(frame): |
|
return torch.cat( |
|
[ |
|
(frame[:, 0:1, :, :] - 0.4631), |
|
(frame[:, 1:2, :, :] - 0.4352), |
|
(frame[:, 2:3, :, :] - 0.3990), |
|
], |
|
1, |
|
) |
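
# The constants above are fixed per-channel RGB means; the same values appear
# in the sepconv / AdaCoF reference code as training-set statistics.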
|
|
|
|
|
class FoldUnfold: |
|
""" |
|
Class to handle folding tensor frame into batch of patches and back to frame again |
|
Thanks to Charlie Tan ([email protected]) for the earier version. |
|
""" |
|
|
|
def __init__(self, height, width, patch_size, overlap): |
|
if height % 2 or width % 2 or patch_size % 2 or overlap % 2: |
|
print( |
|
"only defined for even values of height, width, patch_size size and overlap, odd values will reconstruct incorrectly" |
|
) |
|
return |
|
|
|
self.height = height |
|
self.width = width |
|
|
|
self.patch_size = patch_size |
|
self.overlap = overlap |
|
self.stride = patch_size - overlap |
|
|
|
def fold_to_patches(self, *frames): |
|
""" |
|
args: frames -- list of (1,3,H,W) tensors |
|
returns: list of (B,3,h,w) image patches |
|
""" |
|
|
|
|
|
n_blocks_h = (self.height // (self.stride)) + 1 |
|
n_blocks_w = (self.width // (self.stride)) + 1 |
|
|
|
|
|
self.pad_h = (self.stride * n_blocks_h + self.overlap - self.height) // 2 |
|
self.pad_w = (self.stride * n_blocks_w + self.overlap - self.width) // 2 |
|
self.height_pad = self.height + 2 * self.pad_h |
|
self.width_pad = self.width + 2 * self.pad_w |
|
|
|
|
|
patches_list = [] |
|
for i in range(len(frames)): |
|
padded = F.pad( |
|
frames[i], |
|
(self.pad_w, self.pad_w, self.pad_h, self.pad_h), |
|
mode="reflect", |
|
) |
|
unfolded = F.unfold(padded, self.patch_size, stride=self.stride) |
|
patches = unfolded.permute(2, 1, 0).reshape( |
|
-1, 3, self.patch_size, self.patch_size |
|
) |
|
patches_list.append(patches) |
|
|
|
return patches_list |
|
|
|
def unfold_to_frame(self, patches): |
|
""" |
|
args: patches -- tensor of shape (B,3,h,w) |
|
returns: frame -- tensor of shape (1,3,H,W) |
|
""" |
|
|
|
|
|
frame_unfold = patches.reshape(-1, 3 * self.patch_size**2, 1).permute(2, 1, 0) |
|
|
|
|
|
frame_fold = F.fold( |
|
frame_unfold, |
|
(self.height_pad, self.width_pad), |
|
self.patch_size, |
|
stride=self.stride, |
|
) |
|
|
|
|
|
|
|
ones = torch.ones_like(frame_fold) |
|
ones_unfold = F.unfold(ones, self.patch_size, stride=self.stride) |
|
|
|
|
|
|
|
divisor = F.fold( |
|
ones_unfold, |
|
(self.height_pad, self.width_pad), |
|
self.patch_size, |
|
stride=self.stride, |
|
) |
|
|
|
|
|
frame_div = frame_fold / divisor |
|
|
|
|
|
frame_crop = frame_div[ |
|
:, :, self.pad_h : -self.pad_h, self.pad_w : -self.pad_w |
|
].clone() |
|
|
|
return frame_crop |
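
# Round-trip sketch (all sizes even, as required above):
#   fu = FoldUnfold(height=256, width=448, patch_size=128, overlap=32)
#   (patches,) = fu.fold_to_patches(frame)   # frame: (1,3,256,448) -> patches: (B,3,128,128)
#   frame_rec = fu.unfold_to_frame(patches)  # -> (1,3,256,448), overlaps averaged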
|
|
|
|
|
def read_frame_yuv2rgb(stream, width, height, iFrame, bit_depth, pix_fmt="420"): |
|
if pix_fmt == "420": |
|
multiplier = 1 |
|
uv_factor = 2 |
|
elif pix_fmt == "444": |
|
multiplier = 2 |
|
uv_factor = 1 |
|
    else:
        raise ValueError("Pixel format {} is not supported".format(pix_fmt))
|
|
|
if bit_depth == 8: |
|
datatype = np.uint8 |
|
        # 1.5 bytes per pixel for 8-bit 4:2:0; use integer arithmetic because
        # file.seek() rejects floats
        stream.seek(iFrame * width * height * multiplier * 3 // 2)
|
Y = np.fromfile(stream, dtype=datatype, count=width * height).reshape( |
|
(height, width) |
|
) |
|
|
|
|
|
U = np.fromfile( |
|
stream, dtype=datatype, count=(width // uv_factor) * (height // uv_factor) |
|
).reshape((height // uv_factor, width // uv_factor)) |
|
V = np.fromfile( |
|
stream, dtype=datatype, count=(width // uv_factor) * (height // uv_factor) |
|
).reshape((height // uv_factor, width // uv_factor)) |
|
|
|
else: |
|
datatype = np.uint16 |
|
stream.seek(iFrame * 3 * width * height * multiplier) |
|
Y = np.fromfile(stream, dtype=datatype, count=width * height).reshape( |
|
(height, width) |
|
) |
|
|
|
U = np.fromfile( |
|
stream, dtype=datatype, count=(width // uv_factor) * (height // uv_factor) |
|
).reshape((height // uv_factor, width // uv_factor)) |
|
V = np.fromfile( |
|
stream, dtype=datatype, count=(width // uv_factor) * (height // uv_factor) |
|
).reshape((height // uv_factor, width // uv_factor)) |
|
|
|
if pix_fmt == "420": |
|
yuv = np.empty((height * 3 // 2, width), dtype=datatype) |
|
yuv[0:height, :] = Y |
|
|
|
yuv[height : height + height // 4, :] = U.reshape(-1, width) |
|
yuv[height + height // 4 :, :] = V.reshape(-1, width) |
|
|
|
if bit_depth != 8: |
|
yuv = (yuv / (2**bit_depth - 1) * 255).astype(np.uint8) |
|
|
|
|
|
rgb = cv2.cvtColor(yuv, cv2.COLOR_YUV2RGB_I420) |
|
|
|
else: |
|
yvu = np.stack([Y, V, U], axis=2) |
|
if bit_depth != 8: |
|
yvu = (yvu / (2**bit_depth - 1) * 255).astype(np.uint8) |
|
rgb = cv2.cvtColor(yvu, cv2.COLOR_YCrCb2RGB) |
|
|
|
return rgb |
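
# Typical use (a sketch): decode frame 0 of an 8-bit YUV 4:2:0 raw file.
#   with open("video.yuv", "rb") as f:
#       rgb = read_frame_yuv2rgb(f, 1920, 1080, 0, 8, "420")  # (1080, 1920, 3) uint8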
|
|
|
|
|
def quantize(imTensor): |
|
return imTensor.clamp(0.0, 1.0).mul(255).round() |
|
|
|
|
|
def tensor2rgb(tensor): |
|
""" |
|
Convert GPU Tensor to RGB image (numpy array) |
|
""" |
|
out = [] |
|
for b in range(tensor.shape[0]): |
|
out.append( |
|
np.moveaxis(quantize(tensor[b]).cpu().detach().numpy(), 0, 2).astype( |
|
np.uint8 |
|
) |
|
) |
|
return np.array(out) |
|
|
|
|
|
class Identity(nn.Module): |
|
def __init__(self, *args): |
|
super(Identity, self).__init__() |
|
|
|
def forward(self, x): |
|
return x |
|
|
|
|
|
class SEBlock(nn.Module): |
|
def __init__(self, input_dim, reduction=16): |
|
super(SEBlock, self).__init__() |
|
mid = int(input_dim / reduction) |
|
self.avg_pool = nn.AdaptiveAvgPool2d(1) |
|
self.fc = nn.Sequential( |
|
nn.Linear(input_dim, mid), |
|
nn.ReLU(inplace=True), |
|
nn.Linear(mid, input_dim), |
|
nn.Sigmoid(), |
|
) |
|
|
|
def forward(self, x): |
|
b, c, _, _ = x.size() |
|
y = self.avg_pool(x).view(b, c) |
|
y = self.fc(y).view(b, c, 1, 1) |
|
return x * y |
|
|
|
|
|
class ResNextBlock(nn.Module): |
|
def __init__( |
|
self, down, cin, cout, ks, stride=1, groups=32, base_width=4, norm_layer=None |
|
): |
|
super(ResNextBlock, self).__init__() |
|
if norm_layer is None or norm_layer == "batch": |
|
norm_layer = nn.BatchNorm2d |
|
elif norm_layer == "identity": |
|
norm_layer = Identity |
|
width = int(cout * (base_width / 64.0)) * groups |
|
|
|
self.conv1 = nn.Conv2d(cin, width, kernel_size=1, stride=1, bias=False) |
|
self.bn1 = norm_layer(width) |
|
if down: |
|
self.conv2 = nn.Conv2d( |
|
width, |
|
width, |
|
kernel_size=ks, |
|
stride=stride, |
|
padding=(ks - 1) // 2, |
|
groups=groups, |
|
bias=False, |
|
) |
|
else: |
|
self.conv2 = nn.ConvTranspose2d( |
|
width, |
|
width, |
|
kernel_size=ks, |
|
stride=stride, |
|
padding=(ks - stride) // 2, |
|
groups=groups, |
|
bias=False, |
|
) |
|
self.bn2 = norm_layer(width) |
|
self.conv3 = nn.Conv2d(width, cout, kernel_size=1, stride=1, bias=False) |
|
self.bn3 = norm_layer(cout) |
|
self.relu = nn.ReLU(inplace=True) |
|
self.downsample = None |
|
if stride != 1 or cin != cout: |
|
if down: |
|
self.downsample = nn.Sequential( |
|
nn.Conv2d(cin, cout, kernel_size=1, stride=stride, bias=False), |
|
norm_layer(cout), |
|
) |
|
else: |
|
self.downsample = nn.Sequential( |
|
|
|
nn.ConvTranspose2d( |
|
cin, cout, kernel_size=2, stride=stride, bias=False |
|
), |
|
norm_layer(cout), |
|
) |
|
self.stride = stride |
|
|
|
def forward(self, x): |
|
identity = x |
|
|
|
out = self.conv1(x) |
|
out = self.bn1(out) |
|
out = self.relu(out) |
|
|
|
out = self.conv2(out) |
|
out = self.bn2(out) |
|
out = self.relu(out) |
|
|
|
out = self.conv3(out) |
|
out = self.bn3(out) |
|
|
|
if self.downsample is not None: |
|
identity = self.downsample(x) |
|
|
|
out += identity |
|
out = self.relu(out) |
|
|
|
return out |
|
|
|
|
|
class MultiScaleResNextBlock(nn.Module): |
|
def __init__(self, down, cin, cout, ks_s, ks_l, stride, norm_layer): |
|
super(MultiScaleResNextBlock, self).__init__() |
|
self.resnext_small = ResNextBlock( |
|
down, cin, cout // 2, ks_s, stride, norm_layer=norm_layer |
|
) |
|
self.resnext_large = ResNextBlock( |
|
down, cin, cout // 2, ks_l, stride, norm_layer=norm_layer |
|
) |
|
self.attention = SEBlock(cout) |
|
|
|
def forward(self, tensorCombine): |
|
out_small = self.resnext_small(tensorCombine) |
|
out_large = self.resnext_large(tensorCombine) |
|
out = torch.cat([out_small, out_large], 1) |
|
out = self.attention(out) |
|
return out |
|
|
|
|
|
class UMultiScaleResNext(nn.Module): |
|
def __init__( |
|
self, channels=[64, 128, 256, 512], norm_layer="batch", inplanes=6, **kwargs |
|
): |
|
super(UMultiScaleResNext, self).__init__() |
|
self.conv1 = MultiScaleResNextBlock( |
|
True, inplanes, channels[0], ks_s=3, ks_l=7, stride=2, norm_layer=norm_layer |
|
) |
|
self.conv2 = MultiScaleResNextBlock( |
|
True, |
|
channels[0], |
|
channels[1], |
|
ks_s=3, |
|
ks_l=7, |
|
stride=2, |
|
norm_layer=norm_layer, |
|
) |
|
self.conv3 = MultiScaleResNextBlock( |
|
True, |
|
channels[1], |
|
channels[2], |
|
ks_s=3, |
|
ks_l=5, |
|
stride=2, |
|
norm_layer=norm_layer, |
|
) |
|
self.conv4 = MultiScaleResNextBlock( |
|
True, |
|
channels[2], |
|
channels[3], |
|
ks_s=3, |
|
ks_l=5, |
|
stride=2, |
|
norm_layer=norm_layer, |
|
) |
|
|
|
self.deconv4 = MultiScaleResNextBlock( |
|
True, |
|
channels[3], |
|
channels[3], |
|
ks_s=3, |
|
ks_l=5, |
|
stride=1, |
|
norm_layer=norm_layer, |
|
) |
|
self.deconv3 = MultiScaleResNextBlock( |
|
False, |
|
channels[3], |
|
channels[2], |
|
ks_s=4, |
|
ks_l=6, |
|
stride=2, |
|
norm_layer=norm_layer, |
|
) |
|
self.deconv2 = MultiScaleResNextBlock( |
|
False, |
|
channels[2], |
|
channels[1], |
|
ks_s=4, |
|
ks_l=8, |
|
stride=2, |
|
norm_layer=norm_layer, |
|
) |
|
self.deconv1 = MultiScaleResNextBlock( |
|
False, |
|
channels[1], |
|
channels[0], |
|
ks_s=4, |
|
ks_l=8, |
|
stride=2, |
|
norm_layer=norm_layer, |
|
) |
|
|
|
def forward(self, im0, im2): |
|
tensorJoin = torch.cat([im0, im2], 1) |
|
|
|
tensorConv1 = self.conv1(tensorJoin) |
|
tensorConv2 = self.conv2(tensorConv1) |
|
tensorConv3 = self.conv3(tensorConv2) |
|
tensorConv4 = self.conv4(tensorConv3) |
|
|
|
tensorDeconv4 = self.deconv4(tensorConv4) |
|
tensorDeconv3 = self.deconv3(tensorDeconv4 + tensorConv4) |
|
tensorDeconv2 = self.deconv2(tensorDeconv3 + tensorConv3) |
|
tensorDeconv1 = self.deconv1(tensorDeconv2 + tensorConv2) |
|
|
|
return tensorDeconv1 |
|
|
|
|
|
class MultiInputGridNet(nn.Module): |
|
def __init__(self, in_chs, out_chs, grid_chs=(32, 64, 96), n_row=3, n_col=6): |
|
super(MultiInputGridNet, self).__init__() |
|
|
|
self.n_row = n_row |
|
self.n_col = n_col |
|
self.n_chs = grid_chs |
|
assert ( |
|
len(grid_chs) == self.n_row |
|
), "should give num channels for each row (scale stream)" |
|
assert ( |
|
len(in_chs) == self.n_row |
|
), "should give input channels for each row (scale stream)" |
|
|
|
for r, n_ch in enumerate(self.n_chs): |
|
setattr(self, f"lateral_{r}_0", LateralBlock(in_chs[r], n_ch)) |
|
for c in range(1, self.n_col): |
|
setattr(self, f"lateral_{r}_{c}", LateralBlock(n_ch, n_ch)) |
|
|
|
for r, (in_ch, out_ch) in enumerate(zip(self.n_chs[:-1], self.n_chs[1:])): |
|
for c in range(int(self.n_col / 2)): |
|
setattr(self, f"down_{r}_{c}", DownSamplingBlock(in_ch, out_ch)) |
|
|
|
for r, (in_ch, out_ch) in enumerate(zip(self.n_chs[1:], self.n_chs[:-1])): |
|
for c in range(int(self.n_col / 2)): |
|
setattr(self, f"up_{r}_{c}", UpSamplingBlock(in_ch, out_ch)) |
|
|
|
self.lateral_final = LateralBlock(self.n_chs[0], out_chs) |
|
|
|
def forward(self, *args): |
|
assert len(args) == self.n_row |
|
|
|
|
|
cur_col = list(args) |
|
for c in range(int(self.n_col / 2)): |
|
for r in range(self.n_row): |
|
cur_col[r] = getattr(self, f"lateral_{r}_{c}")(cur_col[r]) |
|
if r != 0: |
|
cur_col[r] += getattr(self, f"down_{r-1}_{c}")(cur_col[r - 1]) |
|
|
|
for c in range(int(self.n_col / 2), self.n_col): |
|
for r in range(self.n_row - 1, -1, -1): |
|
cur_col[r] = getattr(self, f"lateral_{r}_{c}")(cur_col[r]) |
|
if r != self.n_row - 1: |
|
cur_col[r] += getattr(self, f"up_{r}_{c-int(self.n_col/2)}")( |
|
cur_col[r + 1] |
|
) |
|
|
|
return self.lateral_final(cur_col[0]) |
|
|
|
|
|
class MIMOGridNet(nn.Module): |
|
def __init__( |
|
self, in_chs, out_chs, grid_chs=(32, 64, 96), n_row=3, n_col=6, outrow=(0, 1, 2) |
|
): |
|
super(MIMOGridNet, self).__init__() |
|
|
|
self.n_row = n_row |
|
self.n_col = n_col |
|
self.n_chs = grid_chs |
|
self.outrow = outrow |
|
assert ( |
|
len(grid_chs) == self.n_row |
|
), "should give num channels for each row (scale stream)" |
|
assert ( |
|
len(in_chs) == self.n_row |
|
), "should give input channels for each row (scale stream)" |
|
assert len(out_chs) == len( |
|
self.outrow |
|
), "should give out channels for each output row (scale stream)" |
|
|
|
for r, n_ch in enumerate(self.n_chs): |
|
setattr(self, f"lateral_{r}_0", LateralBlock(in_chs[r], n_ch)) |
|
for c in range(1, self.n_col): |
|
setattr(self, f"lateral_{r}_{c}", LateralBlock(n_ch, n_ch)) |
|
|
|
for r, (in_ch, out_ch) in enumerate(zip(self.n_chs[:-1], self.n_chs[1:])): |
|
for c in range(int(self.n_col / 2)): |
|
setattr(self, f"down_{r}_{c}", DownSamplingBlock(in_ch, out_ch)) |
|
|
|
for r, (in_ch, out_ch) in enumerate(zip(self.n_chs[1:], self.n_chs[:-1])): |
|
for c in range(int(self.n_col / 2)): |
|
setattr(self, f"up_{r}_{c}", UpSamplingBlock(in_ch, out_ch)) |
|
|
|
for i, r in enumerate(outrow): |
|
setattr(self, f"lateral_final_{r}", LateralBlock(self.n_chs[r], out_chs[i])) |
|
|
|
def forward(self, *args): |
|
assert len(args) == self.n_row |
|
|
|
|
|
cur_col = list(args) |
|
for c in range(int(self.n_col / 2)): |
|
for r in range(self.n_row): |
|
cur_col[r] = getattr(self, f"lateral_{r}_{c}")(cur_col[r]) |
|
if r != 0: |
|
cur_col[r] += getattr(self, f"down_{r-1}_{c}")(cur_col[r - 1]) |
|
|
|
for c in range(int(self.n_col / 2), self.n_col): |
|
for r in range(self.n_row - 1, -1, -1): |
|
cur_col[r] = getattr(self, f"lateral_{r}_{c}")(cur_col[r]) |
|
if r != self.n_row - 1: |
|
cur_col[r] += getattr(self, f"up_{r}_{c-int(self.n_col/2)}")( |
|
cur_col[r + 1] |
|
) |
|
|
|
out = [] |
|
for r in self.outrow: |
|
out.append(getattr(self, f"lateral_final_{r}")(cur_col[r])) |
|
|
|
return out |
|
|
|
|
|
class GeneralGridNet(nn.Module): |
|
def __init__(self, in_chs, out_chs, grid_chs=(32, 64, 96), n_row=3, n_col=6): |
|
super(GeneralGridNet, self).__init__() |
|
|
|
self.n_row = n_row |
|
self.n_col = n_col |
|
self.n_chs = grid_chs |
|
assert ( |
|
len(grid_chs) == self.n_row |
|
), "should give num channels for each row (scale stream)" |
|
|
|
for r, n_ch in enumerate(self.n_chs): |
|
if r == 0: |
|
setattr(self, f"lateral_{r}_0", LateralBlock(in_chs, n_ch)) |
|
for c in range(1, self.n_col): |
|
setattr(self, f"lateral_{r}_{c}", LateralBlock(n_ch, n_ch)) |
|
|
|
for r, (in_ch, out_ch) in enumerate(zip(self.n_chs[:-1], self.n_chs[1:])): |
|
for c in range(int(self.n_col / 2)): |
|
setattr(self, f"down_{r}_{c}", DownSamplingBlock(in_ch, out_ch)) |
|
|
|
for r, (in_ch, out_ch) in enumerate(zip(self.n_chs[1:], self.n_chs[:-1])): |
|
for c in range(int(self.n_col / 2)): |
|
setattr(self, f"up_{r}_{c}", UpSamplingBlock(in_ch, out_ch)) |
|
|
|
self.lateral_final = LateralBlock(self.n_chs[0], out_chs) |
|
|
|
def forward(self, x): |
|
cur_col = [x] + [None] * (self.n_row - 1) |
|
for c in range(int(self.n_col / 2)): |
|
for r in range(self.n_row): |
|
                if cur_col[r] is not None:
|
cur_col[r] = getattr(self, f"lateral_{r}_{c}")(cur_col[r]) |
|
else: |
|
cur_col[r] = 0.0 |
|
if r != 0: |
|
cur_col[r] += getattr(self, f"down_{r-1}_{c}")(cur_col[r - 1]) |
|
|
|
for c in range(int(self.n_col / 2), self.n_col): |
|
for r in range(self.n_row - 1, -1, -1): |
|
cur_col[r] = getattr(self, f"lateral_{r}_{c}")(cur_col[r]) |
|
if r != self.n_row - 1: |
|
cur_col[r] += getattr(self, f"up_{r}_{c-int(self.n_col/2)}")( |
|
cur_col[r + 1] |
|
) |
|
|
|
return self.lateral_final(cur_col[0]) |
|
|
|
|
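# GridNet below is the fixed 3-row x 6-column variant of the grid networks
# above: rows are scale streams, the first three columns feed the grid and
# downsample between rows, the last three upsample back out, with lateral
# residual blocks along every row.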
|
class GridNet(nn.Module): |
|
def __init__(self, in_chs, out_chs, grid_chs=(32, 64, 96)): |
|
super(GridNet, self).__init__() |
|
|
|
self.n_row = 3 |
|
self.n_col = 6 |
|
self.n_chs = grid_chs |
|
assert ( |
|
len(grid_chs) == self.n_row |
|
), "should give num channels for each row (scale stream)" |
|
|
|
self.lateral_init = LateralBlock(in_chs, self.n_chs[0]) |
|
|
|
for r, n_ch in enumerate(self.n_chs): |
|
for c in range(self.n_col - 1): |
|
setattr(self, f"lateral_{r}_{c}", LateralBlock(n_ch, n_ch)) |
|
|
|
for r, (in_ch, out_ch) in enumerate(zip(self.n_chs[:-1], self.n_chs[1:])): |
|
for c in range(int(self.n_col / 2)): |
|
setattr(self, f"down_{r}_{c}", DownSamplingBlock(in_ch, out_ch)) |
|
|
|
for r, (in_ch, out_ch) in enumerate(zip(self.n_chs[1:], self.n_chs[:-1])): |
|
for c in range(int(self.n_col / 2)): |
|
setattr(self, f"up_{r}_{c}", UpSamplingBlock(in_ch, out_ch)) |
|
|
|
self.lateral_final = LateralBlock(self.n_chs[0], out_chs) |
|
|
|
def forward(self, x): |
|
state_00 = self.lateral_init(x) |
|
state_10 = self.down_0_0(state_00) |
|
state_20 = self.down_1_0(state_10) |
|
|
|
state_01 = self.lateral_0_0(state_00) |
|
state_11 = self.down_0_1(state_01) + self.lateral_1_0(state_10) |
|
state_21 = self.down_1_1(state_11) + self.lateral_2_0(state_20) |
|
|
|
state_02 = self.lateral_0_1(state_01) |
|
state_12 = self.down_0_2(state_02) + self.lateral_1_1(state_11) |
|
state_22 = self.down_1_2(state_12) + self.lateral_2_1(state_21) |
|
|
|
state_23 = self.lateral_2_2(state_22) |
|
state_13 = self.up_1_0(state_23) + self.lateral_1_2(state_12) |
|
state_03 = self.up_0_0(state_13) + self.lateral_0_2(state_02) |
|
|
|
state_24 = self.lateral_2_3(state_23) |
|
state_14 = self.up_1_1(state_24) + self.lateral_1_3(state_13) |
|
state_04 = self.up_0_1(state_14) + self.lateral_0_3(state_03) |
|
|
|
state_25 = self.lateral_2_4(state_24) |
|
state_15 = self.up_1_2(state_25) + self.lateral_1_4(state_14) |
|
state_05 = self.up_0_2(state_15) + self.lateral_0_4(state_04) |
|
|
|
return self.lateral_final(state_05) |
|
|
|
|
|
class LateralBlock(nn.Module): |
|
def __init__(self, ch_in, ch_out): |
|
super(LateralBlock, self).__init__() |
|
self.f = nn.Sequential( |
|
nn.PReLU(), |
|
nn.Conv2d(ch_in, ch_out, kernel_size=3, padding=1), |
|
nn.PReLU(), |
|
nn.Conv2d(ch_out, ch_out, kernel_size=3, padding=1), |
|
) |
|
if ch_in != ch_out: |
|
self.conv = nn.Conv2d(ch_in, ch_out, kernel_size=3, padding=1) |
|
|
|
def forward(self, x): |
|
fx = self.f(x) |
|
if fx.shape[1] != x.shape[1]: |
|
x = self.conv(x) |
|
return fx + x |
|
|
|
|
|
class DownSamplingBlock(nn.Module): |
|
def __init__(self, ch_in, ch_out): |
|
super(DownSamplingBlock, self).__init__() |
|
self.f = nn.Sequential( |
|
nn.PReLU(), |
|
nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=2, padding=1), |
|
nn.PReLU(), |
|
nn.Conv2d(ch_out, ch_out, kernel_size=3, padding=1), |
|
) |
|
|
|
def forward(self, x): |
|
return self.f(x) |
|
|
|
|
|
class UpSamplingBlock(nn.Module): |
|
def __init__(self, ch_in, ch_out): |
|
super(UpSamplingBlock, self).__init__() |
|
self.f = nn.Sequential( |
|
nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False), |
|
nn.PReLU(), |
|
nn.Conv2d(ch_in, ch_out, kernel_size=3, padding=1), |
|
nn.PReLU(), |
|
nn.Conv2d(ch_out, ch_out, kernel_size=3, padding=1), |
|
) |
|
|
|
def forward(self, x): |
|
return self.f(x) |
|
|
|
|
|
|
|
|
|
class Network(torch.nn.Module): |
|
def __init__(self): |
|
super(Network, self).__init__() |
|
|
|
class Extractor(torch.nn.Module): |
|
def __init__(self): |
|
super(Extractor, self).__init__() |
|
|
|
self.netOne = torch.nn.Sequential( |
|
torch.nn.Conv2d( |
|
in_channels=3, |
|
out_channels=16, |
|
kernel_size=3, |
|
stride=2, |
|
padding=1, |
|
), |
|
torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), |
|
torch.nn.Conv2d( |
|
in_channels=16, |
|
out_channels=16, |
|
kernel_size=3, |
|
stride=1, |
|
padding=1, |
|
), |
|
torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), |
|
torch.nn.Conv2d( |
|
in_channels=16, |
|
out_channels=16, |
|
kernel_size=3, |
|
stride=1, |
|
padding=1, |
|
), |
|
torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), |
|
) |
|
|
|
self.netTwo = torch.nn.Sequential( |
|
torch.nn.Conv2d( |
|
in_channels=16, |
|
out_channels=32, |
|
kernel_size=3, |
|
stride=2, |
|
padding=1, |
|
), |
|
torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), |
|
torch.nn.Conv2d( |
|
in_channels=32, |
|
out_channels=32, |
|
kernel_size=3, |
|
stride=1, |
|
padding=1, |
|
), |
|
torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), |
|
torch.nn.Conv2d( |
|
in_channels=32, |
|
out_channels=32, |
|
kernel_size=3, |
|
stride=1, |
|
padding=1, |
|
), |
|
torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), |
|
) |
|
|
|
self.netThr = torch.nn.Sequential( |
|
torch.nn.Conv2d( |
|
in_channels=32, |
|
out_channels=64, |
|
kernel_size=3, |
|
stride=2, |
|
padding=1, |
|
), |
|
torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), |
|
torch.nn.Conv2d( |
|
in_channels=64, |
|
out_channels=64, |
|
kernel_size=3, |
|
stride=1, |
|
padding=1, |
|
), |
|
torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), |
|
torch.nn.Conv2d( |
|
in_channels=64, |
|
out_channels=64, |
|
kernel_size=3, |
|
stride=1, |
|
padding=1, |
|
), |
|
torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), |
|
) |
|
|
|
self.netFou = torch.nn.Sequential( |
|
torch.nn.Conv2d( |
|
in_channels=64, |
|
out_channels=96, |
|
kernel_size=3, |
|
stride=2, |
|
padding=1, |
|
), |
|
torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), |
|
torch.nn.Conv2d( |
|
in_channels=96, |
|
out_channels=96, |
|
kernel_size=3, |
|
stride=1, |
|
padding=1, |
|
), |
|
torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), |
|
torch.nn.Conv2d( |
|
in_channels=96, |
|
out_channels=96, |
|
kernel_size=3, |
|
stride=1, |
|
padding=1, |
|
), |
|
torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), |
|
) |
|
|
|
self.netFiv = torch.nn.Sequential( |
|
torch.nn.Conv2d( |
|
in_channels=96, |
|
out_channels=128, |
|
kernel_size=3, |
|
stride=2, |
|
padding=1, |
|
), |
|
torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), |
|
torch.nn.Conv2d( |
|
in_channels=128, |
|
out_channels=128, |
|
kernel_size=3, |
|
stride=1, |
|
padding=1, |
|
), |
|
torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), |
|
torch.nn.Conv2d( |
|
in_channels=128, |
|
out_channels=128, |
|
kernel_size=3, |
|
stride=1, |
|
padding=1, |
|
), |
|
torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), |
|
) |
|
|
|
self.netSix = torch.nn.Sequential( |
|
torch.nn.Conv2d( |
|
in_channels=128, |
|
out_channels=196, |
|
kernel_size=3, |
|
stride=2, |
|
padding=1, |
|
), |
|
torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), |
|
torch.nn.Conv2d( |
|
in_channels=196, |
|
out_channels=196, |
|
kernel_size=3, |
|
stride=1, |
|
padding=1, |
|
), |
|
torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), |
|
torch.nn.Conv2d( |
|
in_channels=196, |
|
out_channels=196, |
|
kernel_size=3, |
|
stride=1, |
|
padding=1, |
|
), |
|
torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), |
|
) |
|
|
|
|
|
|
|
def forward(self, tenInput): |
|
tenOne = self.netOne(tenInput) |
|
tenTwo = self.netTwo(tenOne) |
|
tenThr = self.netThr(tenTwo) |
|
tenFou = self.netFou(tenThr) |
|
tenFiv = self.netFiv(tenFou) |
|
tenSix = self.netSix(tenFiv) |
|
|
|
return [tenOne, tenTwo, tenThr, tenFou, tenFiv, tenSix] |
|
|
|
|
|
|
|
|
|
|
|
class Decoder(torch.nn.Module): |
|
def __init__(self, intLevel): |
|
super(Decoder, self).__init__() |
|
|
|
intPrevious = [ |
|
None, |
|
None, |
|
81 + 32 + 2 + 2, |
|
81 + 64 + 2 + 2, |
|
81 + 96 + 2 + 2, |
|
81 + 128 + 2 + 2, |
|
81, |
|
None, |
|
][intLevel + 1] |
|
intCurrent = [ |
|
None, |
|
None, |
|
81 + 32 + 2 + 2, |
|
81 + 64 + 2 + 2, |
|
81 + 96 + 2 + 2, |
|
81 + 128 + 2 + 2, |
|
81, |
|
None, |
|
][intLevel + 0] |
|
|
|
if intLevel < 6: |
|
self.netUpflow = torch.nn.ConvTranspose2d( |
|
in_channels=2, |
|
out_channels=2, |
|
kernel_size=4, |
|
stride=2, |
|
padding=1, |
|
) |
|
if intLevel < 6: |
|
self.netUpfeat = torch.nn.ConvTranspose2d( |
|
in_channels=intPrevious + 128 + 128 + 96 + 64 + 32, |
|
out_channels=2, |
|
kernel_size=4, |
|
stride=2, |
|
padding=1, |
|
) |
|
if intLevel < 6: |
|
self.fltBackwarp = [None, None, None, 5.0, 2.5, 1.25, 0.625, None][ |
|
intLevel + 1 |
|
] |
|
|
|
self.netOne = torch.nn.Sequential( |
|
torch.nn.Conv2d( |
|
in_channels=intCurrent, |
|
out_channels=128, |
|
kernel_size=3, |
|
stride=1, |
|
padding=1, |
|
), |
|
torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), |
|
) |
|
|
|
self.netTwo = torch.nn.Sequential( |
|
torch.nn.Conv2d( |
|
in_channels=intCurrent + 128, |
|
out_channels=128, |
|
kernel_size=3, |
|
stride=1, |
|
padding=1, |
|
), |
|
torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), |
|
) |
|
|
|
self.netThr = torch.nn.Sequential( |
|
torch.nn.Conv2d( |
|
in_channels=intCurrent + 128 + 128, |
|
out_channels=96, |
|
kernel_size=3, |
|
stride=1, |
|
padding=1, |
|
), |
|
torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), |
|
) |
|
|
|
self.netFou = torch.nn.Sequential( |
|
torch.nn.Conv2d( |
|
in_channels=intCurrent + 128 + 128 + 96, |
|
out_channels=64, |
|
kernel_size=3, |
|
stride=1, |
|
padding=1, |
|
), |
|
torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), |
|
) |
|
|
|
self.netFiv = torch.nn.Sequential( |
|
torch.nn.Conv2d( |
|
in_channels=intCurrent + 128 + 128 + 96 + 64, |
|
out_channels=32, |
|
kernel_size=3, |
|
stride=1, |
|
padding=1, |
|
), |
|
torch.nn.LeakyReLU(inplace=False, negative_slope=0.1), |
|
) |
|
|
|
self.netSix = torch.nn.Sequential( |
|
torch.nn.Conv2d( |
|
in_channels=intCurrent + 128 + 128 + 96 + 64 + 32, |
|
out_channels=2, |
|
kernel_size=3, |
|
stride=1, |
|
padding=1, |
|
) |
|
) |
|
|
|
|
|
|
|
def forward(self, tenFirst, tenSecond, objPrevious): |
|
                if objPrevious is None:
|
|
|
tenVolume = torch.nn.functional.leaky_relu( |
|
input=FunctionCorrelation( |
|
tenFirst=tenFirst, tenSecond=tenSecond |
|
), |
|
negative_slope=0.1, |
|
inplace=False, |
|
) |
|
|
|
tenFeat = torch.cat([tenVolume], 1) |
|
|
|
                else:
|
tenFlow = self.netUpflow(objPrevious["tenFlow"]) |
|
tenFeat = self.netUpfeat(objPrevious["tenFeat"]) |
|
|
|
tenVolume = torch.nn.functional.leaky_relu( |
|
input=FunctionCorrelation( |
|
tenFirst=tenFirst, |
|
tenSecond=backwarp( |
|
tenInput=tenSecond, tenFlow=tenFlow * self.fltBackwarp |
|
), |
|
), |
|
negative_slope=0.1, |
|
inplace=False, |
|
) |
|
|
|
tenFeat = torch.cat([tenVolume, tenFirst, tenFlow, tenFeat], 1) |
|
|
|
|
|
|
|
tenFeat = torch.cat([self.netOne(tenFeat), tenFeat], 1) |
|
tenFeat = torch.cat([self.netTwo(tenFeat), tenFeat], 1) |
|
tenFeat = torch.cat([self.netThr(tenFeat), tenFeat], 1) |
|
tenFeat = torch.cat([self.netFou(tenFeat), tenFeat], 1) |
|
tenFeat = torch.cat([self.netFiv(tenFeat), tenFeat], 1) |
|
|
|
tenFlow = self.netSix(tenFeat) |
|
|
|
return {"tenFlow": tenFlow, "tenFeat": tenFeat} |
|
|
|
|
|
|
|
|
|
|
|
        class Refiner(torch.nn.Module):
            def __init__(self):
                super(Refiner, self).__init__()

                self.netMain = torch.nn.Sequential(
                    torch.nn.Conv2d(
                        in_channels=81 + 32 + 2 + 2 + 128 + 128 + 96 + 64 + 32,
                        out_channels=128,
                        kernel_size=3,
                        stride=1,
                        padding=1,
                        dilation=1,
                    ),
                    torch.nn.LeakyReLU(inplace=False, negative_slope=0.1),
                    torch.nn.Conv2d(
                        in_channels=128,
                        out_channels=128,
                        kernel_size=3,
                        stride=1,
                        padding=2,
                        dilation=2,
                    ),
                    torch.nn.LeakyReLU(inplace=False, negative_slope=0.1),
                    torch.nn.Conv2d(
                        in_channels=128,
                        out_channels=128,
                        kernel_size=3,
                        stride=1,
                        padding=4,
                        dilation=4,
                    ),
                    torch.nn.LeakyReLU(inplace=False, negative_slope=0.1),
                    torch.nn.Conv2d(
                        in_channels=128,
                        out_channels=96,
                        kernel_size=3,
                        stride=1,
                        padding=8,
                        dilation=8,
                    ),
                    torch.nn.LeakyReLU(inplace=False, negative_slope=0.1),
                    torch.nn.Conv2d(
                        in_channels=96,
                        out_channels=64,
                        kernel_size=3,
                        stride=1,
                        padding=16,
                        dilation=16,
                    ),
                    torch.nn.LeakyReLU(inplace=False, negative_slope=0.1),
                    torch.nn.Conv2d(
                        in_channels=64,
                        out_channels=32,
                        kernel_size=3,
                        stride=1,
                        padding=1,
                        dilation=1,
                    ),
                    torch.nn.LeakyReLU(inplace=False, negative_slope=0.1),
                    torch.nn.Conv2d(
                        in_channels=32,
                        out_channels=2,
                        kernel_size=3,
                        stride=1,
                        padding=1,
                        dilation=1,
                    ),
                )

            def forward(self, tenInput):
                return self.netMain(tenInput)

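        # The Refiner is PWC-Net's context network: a stack of 3x3 convolutions
        # with dilations 1, 2, 4, 8, 16, 1, 1 that enlarges the receptive field
        # and predicts a residual correction added to the decoder's flow.
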
        self.netExtractor = Extractor()

        self.netTwo = Decoder(2)
        self.netThr = Decoder(3)
        self.netFou = Decoder(4)
        self.netFiv = Decoder(5)
        self.netSix = Decoder(6)

        self.netRefiner = Refiner()

        self.load_state_dict(
            {
                strKey.replace("module", "net"): tenWeight
                for strKey, tenWeight in torch.hub.load_state_dict_from_url(
                    url="http://content.sniklaus.com/github/pytorch-pwc/network-"
                    + "default"
                    + ".pytorch",
                    model_dir=get_ckpt_container_path(MODEL_TYPE),
                ).items()
            }
        )

    def forward(self, tenFirst, tenSecond, *args):
        if len(args) == 0:
            tenFirst = self.netExtractor(tenFirst)
            tenSecond = self.netExtractor(tenSecond)
        else:
            # Pre-extracted feature pyramids were passed in directly.
            tenFirst, tenSecond = args

        objEstimate = self.netSix(tenFirst[-1], tenSecond[-1], None)
        objEstimate = self.netFiv(tenFirst[-2], tenSecond[-2], objEstimate)
        objEstimate = self.netFou(tenFirst[-3], tenSecond[-3], objEstimate)
        objEstimate = self.netThr(tenFirst[-4], tenSecond[-4], objEstimate)
        objEstimate = self.netTwo(tenFirst[-5], tenSecond[-5], objEstimate)

        return objEstimate["tenFlow"] + self.netRefiner(objEstimate["tenFeat"])

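    # The flow returned here is at 1/4 of the input resolution and on the
    # network's internal scale; callers multiply by 20 and bilinearly upsample
    # to recover pixel-unit flow at full resolution (see estimate() and
    # STMFNet_Model.forward below).
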
    def extract_pyramid(self, tenFirst, tenSecond):
        return self.netExtractor(tenFirst), self.netExtractor(tenSecond)

    def extract_pyramid_single(self, tenFirst):
        return self.netExtractor(tenFirst)


netNetwork = None


def estimate(tenFirst, tenSecond):
    global netNetwork

    if netNetwork is None:
        netNetwork = PWCNet().cuda().eval()

    assert tenFirst.shape[1] == tenSecond.shape[1]
    assert tenFirst.shape[2] == tenSecond.shape[2]

    intWidth = tenFirst.shape[2]
    intHeight = tenFirst.shape[1]

    # Only the 1024x436 frame size used by the original demo is supported here.
    assert intWidth == 1024
    assert intHeight == 436

    tenPreprocessedFirst = tenFirst.cuda().view(1, 3, intHeight, intWidth)
    tenPreprocessedSecond = tenSecond.cuda().view(1, 3, intHeight, intWidth)

    # Resize to the next multiple of 64 so all six pyramid levels divide evenly.
    intPreprocessedWidth = int(math.floor(math.ceil(intWidth / 64.0) * 64.0))
    intPreprocessedHeight = int(math.floor(math.ceil(intHeight / 64.0) * 64.0))

    tenPreprocessedFirst = torch.nn.functional.interpolate(
        input=tenPreprocessedFirst,
        size=(intPreprocessedHeight, intPreprocessedWidth),
        mode="bilinear",
        align_corners=False,
    )
    tenPreprocessedSecond = torch.nn.functional.interpolate(
        input=tenPreprocessedSecond,
        size=(intPreprocessedHeight, intPreprocessedWidth),
        mode="bilinear",
        align_corners=False,
    )

    tenFlow = 20.0 * torch.nn.functional.interpolate(
        input=netNetwork(tenPreprocessedFirst, tenPreprocessedSecond),
        size=(intHeight, intWidth),
        mode="bilinear",
        align_corners=False,
    )

    # Undo the effect of the resize on the flow magnitudes.
    tenFlow[:, 0, :, :] *= float(intWidth) / float(intPreprocessedWidth)
    tenFlow[:, 1, :, :] *= float(intHeight) / float(intPreprocessedHeight)

    return tenFlow[0, :, :, :].cpu()

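# Usage sketch for estimate() (illustrative only; assumes a CUDA device and
# 3 x 436 x 1024 float tensors in [0, 1], matching the asserts above):
#
#     tenFirst = torch.rand(3, 436, 1024)
#     tenSecond = torch.rand(3, 436, 1024)
#     tenFlow = estimate(tenFirst, tenSecond)  # -> 2 x 436 x 1024, in pixels
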

class UNet3d_18(nn.Module):
    def __init__(self, channels=[32, 64, 96, 128], bn=True):
        super(UNet3d_18, self).__init__()
        growth = 2  # concatenated skip connections double the channel count
        upmode = "transpose"

        self.channels = channels

        self.lrelu = nn.LeakyReLU(0.2, True)

        self.encoder = r3d_18(bn=bn, channels=channels)

        self.decoder = nn.Sequential(
            Conv_3d(
                channels[::-1][0],
                channels[::-1][1],
                kernel_size=3,
                padding=1,
                bias=True,
            ),
            upConv3D(
                channels[::-1][1] * growth,
                channels[::-1][2],
                kernel_size=(3, 4, 4),
                stride=(1, 2, 2),
                padding=(1, 1, 1),
                upmode=upmode,
            ),
            upConv3D(
                channels[::-1][2] * growth,
                channels[::-1][3],
                kernel_size=(3, 4, 4),
                stride=(1, 2, 2),
                padding=(1, 1, 1),
                upmode=upmode,
            ),
            Conv_3d(
                channels[::-1][3] * growth,
                channels[::-1][3],
                kernel_size=3,
                padding=1,
                bias=True,
            ),
            upConv3D(
                channels[::-1][3] * growth,
                channels[::-1][3],
                kernel_size=(3, 4, 4),
                stride=(1, 2, 2),
                padding=(1, 1, 1),
                upmode=upmode,
            ),
        )

        self.feature_fuse = nn.Sequential(
            *(
                [
                    nn.Conv2d(
                        channels[::-1][3] * 5,
                        channels[::-1][3],
                        kernel_size=1,
                        stride=1,
                        bias=False,
                    )
                ]
                + [nn.BatchNorm2d(channels[::-1][3]) if bn else nn.Identity()]
            )
        )

        self.outconv = nn.Sequential(
            nn.ReflectionPad2d(3),
            nn.Conv2d(channels[::-1][3], 3, kernel_size=7, stride=1, padding=0),
        )

    def forward(self, im1, im3, im5, im7, im4_tilde):
        # Stack the four input frames and the synthesized middle frame along a
        # temporal axis: (B, C, T=5, H, W).
        images = torch.stack((im1, im3, im4_tilde, im5, im7), dim=2)

        x_0, x_1, x_2, x_3, x_4 = self.encoder(images)

        dx_3 = self.lrelu(self.decoder[0](x_4))
        dx_3 = torch.cat([dx_3, x_3], dim=1)

        dx_2 = self.lrelu(self.decoder[1](dx_3))
        dx_2 = torch.cat([dx_2, x_2], dim=1)

        dx_1 = self.lrelu(self.decoder[2](dx_2))
        dx_1 = torch.cat([dx_1, x_1], dim=1)

        dx_0 = self.lrelu(self.decoder[3](dx_1))
        dx_0 = torch.cat([dx_0, x_0], dim=1)

        dx_out = self.lrelu(self.decoder[4](dx_0))
        # Fold the temporal dimension into channels before the 2D fusion layers.
        dx_out = torch.cat(torch.unbind(dx_out, 2), 1)

        out = self.lrelu(self.feature_fuse(dx_out))
        out = self.outconv(out)

        return out

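# In STMFNet_Model below, UNet3d_18 acts as the dynamic texture generator:
# given the four input frames plus the synthesized middle frame, it predicts a
# residual that is added on top of the multi-scale synthesis output.
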

class KernelEstimation(torch.nn.Module):
    def __init__(self, kernel_size):
        super(KernelEstimation, self).__init__()
        self.kernel_size = kernel_size

        def Subnet_offset(ks):
            return torch.nn.Sequential(
                torch.nn.Conv2d(
                    in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1
                ),
                torch.nn.ReLU(inplace=False),
                torch.nn.Conv2d(
                    in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1
                ),
                torch.nn.ReLU(inplace=False),
                torch.nn.Conv2d(
                    in_channels=64, out_channels=ks, kernel_size=3, stride=1, padding=1
                ),
                torch.nn.ReLU(inplace=False),
                torch.nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True),
                torch.nn.Conv2d(
                    in_channels=ks, out_channels=ks, kernel_size=3, stride=1, padding=1
                ),
            )

        def Subnet_weight(ks):
            return torch.nn.Sequential(
                torch.nn.Conv2d(
                    in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1
                ),
                torch.nn.ReLU(inplace=False),
                torch.nn.Conv2d(
                    in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1
                ),
                torch.nn.ReLU(inplace=False),
                torch.nn.Conv2d(
                    in_channels=64, out_channels=ks, kernel_size=3, stride=1, padding=1
                ),
                torch.nn.ReLU(inplace=False),
                torch.nn.Upsample(scale_factor=2, mode="bilinear", align_corners=True),
                torch.nn.Conv2d(
                    in_channels=ks, out_channels=ks, kernel_size=3, stride=1, padding=1
                ),
                torch.nn.Softmax(dim=1),
            )

        def Subnet_offset_ds(ks):
            return torch.nn.Sequential(
                torch.nn.Conv2d(
                    in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1
                ),
                torch.nn.ReLU(inplace=False),
                torch.nn.Conv2d(
                    in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1
                ),
                torch.nn.ReLU(inplace=False),
                torch.nn.Conv2d(
                    in_channels=64, out_channels=ks, kernel_size=3, stride=1, padding=1
                ),
            )

        def Subnet_weight_ds(ks):
            return torch.nn.Sequential(
                torch.nn.Conv2d(
                    in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1
                ),
                torch.nn.ReLU(inplace=False),
                torch.nn.Conv2d(
                    in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1
                ),
                torch.nn.ReLU(inplace=False),
                torch.nn.Conv2d(
                    in_channels=64, out_channels=ks, kernel_size=3, stride=1, padding=1
                ),
                torch.nn.Softmax(dim=1),
            )

        def Subnet_offset_us(ks):
            return torch.nn.Sequential(
                torch.nn.Conv2d(
                    in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1
                ),
                torch.nn.ReLU(inplace=False),
                torch.nn.Conv2d(
                    in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1
                ),
                torch.nn.ReLU(inplace=False),
                torch.nn.Conv2d(
                    in_channels=64, out_channels=ks, kernel_size=3, stride=1, padding=1
                ),
                torch.nn.ReLU(inplace=False),
                torch.nn.Upsample(scale_factor=4, mode="bilinear", align_corners=True),
                torch.nn.Conv2d(
                    in_channels=ks, out_channels=ks, kernel_size=3, stride=1, padding=1
                ),
            )

        def Subnet_weight_us(ks):
            return torch.nn.Sequential(
                torch.nn.Conv2d(
                    in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1
                ),
                torch.nn.ReLU(inplace=False),
                torch.nn.Conv2d(
                    in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1
                ),
                torch.nn.ReLU(inplace=False),
                torch.nn.Conv2d(
                    in_channels=64, out_channels=ks, kernel_size=3, stride=1, padding=1
                ),
                torch.nn.ReLU(inplace=False),
                torch.nn.Upsample(scale_factor=4, mode="bilinear", align_corners=True),
                torch.nn.Conv2d(
                    in_channels=ks, out_channels=ks, kernel_size=3, stride=1, padding=1
                ),
                torch.nn.Softmax(dim=1),
            )

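        # The three subnet families differ only in output resolution: the *_ds
        # variants stay at the feature resolution (for the downsampled branch),
        # the plain variants upsample 2x (native resolution), and the *_us
        # variants upsample 4x (for the 8-tap upsampled branch). The weight
        # subnets end in a channel-wise Softmax so each per-pixel AdaCoF kernel
        # sums to one.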
        self.moduleWeight1_ds = Subnet_weight_ds(self.kernel_size**2)
        self.moduleAlpha1_ds = Subnet_offset_ds(self.kernel_size**2)
        self.moduleBeta1_ds = Subnet_offset_ds(self.kernel_size**2)
        self.moduleWeight2_ds = Subnet_weight_ds(self.kernel_size**2)
        self.moduleAlpha2_ds = Subnet_offset_ds(self.kernel_size**2)
        self.moduleBeta2_ds = Subnet_offset_ds(self.kernel_size**2)

        self.moduleWeight1 = Subnet_weight(self.kernel_size**2)
        self.moduleAlpha1 = Subnet_offset(self.kernel_size**2)
        self.moduleBeta1 = Subnet_offset(self.kernel_size**2)
        self.moduleWeight2 = Subnet_weight(self.kernel_size**2)
        self.moduleAlpha2 = Subnet_offset(self.kernel_size**2)
        self.moduleBeta2 = Subnet_offset(self.kernel_size**2)

        self.moduleWeight1_us = Subnet_weight_us(self.kernel_size**2)
        self.moduleAlpha1_us = Subnet_offset_us(self.kernel_size**2)
        self.moduleBeta1_us = Subnet_offset_us(self.kernel_size**2)
        self.moduleWeight2_us = Subnet_weight_us(self.kernel_size**2)
        self.moduleAlpha2_us = Subnet_offset_us(self.kernel_size**2)
        self.moduleBeta2_us = Subnet_offset_us(self.kernel_size**2)

    def forward(self, tensorCombine):
        # Kernels for the first frame at the three scales.
        Weight1_ds = self.moduleWeight1_ds(tensorCombine)
        Weight1 = self.moduleWeight1(tensorCombine)
        Weight1_us = self.moduleWeight1_us(tensorCombine)
        Alpha1_ds = self.moduleAlpha1_ds(tensorCombine)
        Alpha1 = self.moduleAlpha1(tensorCombine)
        Alpha1_us = self.moduleAlpha1_us(tensorCombine)
        Beta1_ds = self.moduleBeta1_ds(tensorCombine)
        Beta1 = self.moduleBeta1(tensorCombine)
        Beta1_us = self.moduleBeta1_us(tensorCombine)

        # Kernels for the second frame at the three scales.
        Weight2_ds = self.moduleWeight2_ds(tensorCombine)
        Weight2 = self.moduleWeight2(tensorCombine)
        Weight2_us = self.moduleWeight2_us(tensorCombine)
        Alpha2_ds = self.moduleAlpha2_ds(tensorCombine)
        Alpha2 = self.moduleAlpha2(tensorCombine)
        Alpha2_us = self.moduleAlpha2_us(tensorCombine)
        Beta2_ds = self.moduleBeta2_ds(tensorCombine)
        Beta2 = self.moduleBeta2(tensorCombine)
        Beta2_us = self.moduleBeta2_us(tensorCombine)

        return (
            Weight1_ds,
            Alpha1_ds,
            Beta1_ds,
            Weight2_ds,
            Alpha2_ds,
            Beta2_ds,
            Weight1,
            Alpha1,
            Beta1,
            Weight2,
            Alpha2,
            Beta2,
            Weight1_us,
            Alpha1_us,
            Beta1_us,
            Weight2_us,
            Alpha2_us,
            Beta2_us,
        )

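# KernelEstimation.forward returns an 18-tuple that STMFNet_Model.forward
# consumes in three groups of six: [:6] are the downsampled-scale kernels,
# [6:12] the native-scale kernels, and [12:] the upsampled-scale kernels, each
# group holding (Weight, Alpha, Beta) for the two input frames.
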

class STMFNet_Model(torch.nn.Module):
    def __init__(self):
        super(STMFNet_Model, self).__init__()

        class Metric(torch.nn.Module):
            """Occlusion metric for softmax splatting: a learnable scale
            applied to the per-pixel L1 photometric error of the
            backward-warped frame."""

            def __init__(self):
                super(Metric, self).__init__()
                self.paramScale = torch.nn.Parameter(-torch.ones(1, 1, 1, 1))

            def forward(self, tenFirst, tenSecond, tenFlow):
                return self.paramScale * F.l1_loss(
                    input=tenFirst,
                    target=backwarp(tenSecond, tenFlow),
                    reduction="none",
                ).mean(1, True)

        self.kernel_size = 5
        self.dilation = 1
        self.featc = [64, 128, 256, 512]
        self.featnorm = "batch"
        self.finetune_pwc = False

        self.kernel_pad = int(((self.kernel_size - 1) * self.dilation) / 2.0)

        self.feature_extractor = UMultiScaleResNext(
            self.featc, norm_layer=self.featnorm
        )

        self.get_kernel = KernelEstimation(self.kernel_size)

        self.modulePad = torch.nn.ReplicationPad2d(
            [self.kernel_pad, self.kernel_pad, self.kernel_pad, self.kernel_pad]
        )

        self.moduleAdaCoF = FunctionAdaCoF.apply

        self.gauss_kernel = torch.nn.Parameter(
            gaussian_kernel(5, 0.5).repeat(3, 1, 1, 1), requires_grad=False
        )

        self.upsampler = Upsampler_8tap()

        self.scale_synthesis = MIMOGridNet(
            (6, 6 + 6, 6), (3,), grid_chs=(32, 64, 96), n_row=3, n_col=4, outrow=(1,)
        )

        self.flow_estimator = PWCNet()

        self.softsplat = ModuleSoftsplat(strType="softmax")

        self.metric = Metric()

        self.dyntex_generator = UNet3d_18(bn=self.featnorm)

        # Keep the pretrained PWC-Net frozen unless fine-tuning is requested.
        if not self.finetune_pwc:
            for param in self.flow_estimator.parameters():
                param.requires_grad = False

    def forward(self, I0, I1, I2, I3):
        h0, w0 = int(I1.size(2)), int(I1.size(3))
        h2, w2 = int(I2.size(2)), int(I2.size(3))
        if h0 != h2 or w0 != w2:
            raise ValueError("Frame sizes do not match")

        # Reflect-pad the height and width up to multiples of 128 so that all
        # pyramid levels divide evenly; the padding is cropped off at the end.
        h_padded = False
        w_padded = False
        if h0 % 128 != 0:
            pad_h = 128 - (h0 % 128)
            I0 = F.pad(I0, (0, 0, 0, pad_h), mode="reflect")
            I1 = F.pad(I1, (0, 0, 0, pad_h), mode="reflect")
            I2 = F.pad(I2, (0, 0, 0, pad_h), mode="reflect")
            I3 = F.pad(I3, (0, 0, 0, pad_h), mode="reflect")
            h_padded = True

        if w0 % 128 != 0:
            pad_w = 128 - (w0 % 128)
            I0 = F.pad(I0, (0, pad_w, 0, 0), mode="reflect")
            I1 = F.pad(I1, (0, pad_w, 0, 0), mode="reflect")
            I2 = F.pad(I2, (0, pad_w, 0, 0), mode="reflect")
            I3 = F.pad(I3, (0, pad_w, 0, 0), mode="reflect")
            w_padded = True

        feats = self.feature_extractor(moduleNormalize(I1), moduleNormalize(I2))
        kernelest = self.get_kernel(feats)
        Weight1_ds, Alpha1_ds, Beta1_ds, Weight2_ds, Alpha2_ds, Beta2_ds = kernelest[:6]
        Weight1, Alpha1, Beta1, Weight2, Alpha2, Beta2 = kernelest[6:12]
        Weight1_us, Alpha1_us, Beta1_us, Weight2_us, Alpha2_us, Beta2_us = kernelest[
            12:
        ]

        # Native-scale AdaCoF warping of both frames.
        tensorAdaCoF1 = self.moduleAdaCoF(
            self.modulePad(I1), Weight1, Alpha1, Beta1, self.dilation
        )
        tensorAdaCoF2 = self.moduleAdaCoF(
            self.modulePad(I2), Weight2, Alpha2, Beta2, self.dilation
        )

        # Downsampled branch: Gaussian blur, then bilinear downsampling to half
        # resolution before AdaCoF warping.
        c, h, w = I1.shape[1:]
        p = (self.gauss_kernel.shape[-1] - 1) // 2
        I1_blur = F.conv2d(
            F.pad(I1, pad=(p, p, p, p), mode="reflect"), self.gauss_kernel, groups=c
        )
        I2_blur = F.conv2d(
            F.pad(I2, pad=(p, p, p, p), mode="reflect"), self.gauss_kernel, groups=c
        )
        I1_ds = F.interpolate(
            I1_blur, size=(h // 2, w // 2), mode="bilinear", align_corners=False
        )
        I2_ds = F.interpolate(
            I2_blur, size=(h // 2, w // 2), mode="bilinear", align_corners=False
        )
        tensorAdaCoF1_ds = self.moduleAdaCoF(
            self.modulePad(I1_ds), Weight1_ds, Alpha1_ds, Beta1_ds, self.dilation
        )
        tensorAdaCoF2_ds = self.moduleAdaCoF(
            self.modulePad(I2_ds), Weight2_ds, Alpha2_ds, Beta2_ds, self.dilation
        )

        # Upsampled branch: 8-tap upsampling to double resolution before AdaCoF.
        I1_us = self.upsampler(I1)
        I2_us = self.upsampler(I2)
        tensorAdaCoF1_us = self.moduleAdaCoF(
            self.modulePad(I1_us), Weight1_us, Alpha1_us, Beta1_us, self.dilation
        )
        tensorAdaCoF2_us = self.moduleAdaCoF(
            self.modulePad(I2_us), Weight2_us, Alpha2_us, Beta2_us, self.dilation
        )

        # Bidirectional flow estimation, followed by softmax splatting of both
        # frames halfway along each flow.
        pyramid0, pyramid2 = self.flow_estimator.extract_pyramid(I1, I2)
        flow_0_2 = 20 * self.flow_estimator(I1, I2, pyramid0, pyramid2)
        flow_0_2 = F.interpolate(
            flow_0_2, size=(h, w), mode="bilinear", align_corners=False
        )
        flow_2_0 = 20 * self.flow_estimator(I2, I1, pyramid2, pyramid0)
        flow_2_0 = F.interpolate(
            flow_2_0, size=(h, w), mode="bilinear", align_corners=False
        )
        metric_0_2 = self.metric(I1, I2, flow_0_2)
        metric_2_0 = self.metric(I2, I1, flow_2_0)
        tensorSoftsplat0 = self.softsplat(I1, 0.5 * flow_0_2, metric_0_2)
        tensorSoftsplat2 = self.softsplat(I2, 0.5 * flow_2_0, metric_2_0)

        # Fuse the three scales in the grid synthesis network, then add the
        # dynamic texture residual.
        tensorCombine_us = torch.cat([tensorAdaCoF1_us, tensorAdaCoF2_us], dim=1)
        tensorCombine = torch.cat(
            [tensorAdaCoF1, tensorAdaCoF2, tensorSoftsplat0, tensorSoftsplat2], dim=1
        )
        tensorCombine_ds = torch.cat([tensorAdaCoF1_ds, tensorAdaCoF2_ds], dim=1)
        output_tilde = self.scale_synthesis(
            tensorCombine_us, tensorCombine, tensorCombine_ds
        )[0]

        dyntex = self.dyntex_generator(I0, I1, I2, I3, output_tilde)
        output = output_tilde + dyntex

        # Crop away any padding added at the start.
        if h_padded:
            output = output[:, :, 0:h0, :]
        if w_padded:
            output = output[:, :, :, 0:w0]

        if self.training:
            return {"frame1": output}
        else:
            return output
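
# Usage sketch for STMFNet_Model (illustrative only; in practice the
# surrounding framework drives it). Given four consecutive frames as
# (B, 3, H, W) float tensors, the model synthesizes the frame between I1
# and I2:
#
#     model = STMFNet_Model().cuda().eval()
#     with torch.no_grad():
#         mid = model(I0, I1, I2, I3)  # (B, 3, H, W)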