|
""" |
|
https://github.com/98mxr/IFUNet/blob/main/model/IFUNet.py |
|
https://github.com/98mxr/IFUNet/blob/main/model/cbam.py |
|
https://github.com/98mxr/IFUNet/blob/main/model/warplayer.py |
|
https://github.com/98mxr/IFUNet/blob/5be535c8cff66d6fa1967252685719df4c0620e4/model/RIFE.py |
|
https://github.com/98mxr/IFUNet/blob/main/model/rrdb.py |
|
https://github.com/98mxr/IFUNet/blob/main/model/ResynNet.py |
|
""" |
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
from comfy.model_management import get_torch_device |
|
|
|
# Cache of base sampling grids built by warp(); keys are
# (str(flow.device), str(flow.size())) tuples. Entries are never evicted.
backwarp_tenGrid = {}

# Torch device returned by comfy.model_management.get_torch_device(),
# resolved once when this module is imported.
device = get_torch_device()
|
|
|
|
|
def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
    """Build a biased 2-D convolution followed by a per-channel PReLU.

    Args:
        in_planes: number of input channels.
        out_planes: number of output channels (also the PReLU parameter count).
        kernel_size: convolution kernel size.
        stride: convolution stride.
        padding: zero padding added by the convolution.
        dilation: convolution dilation.

    Returns:
        nn.Sequential: ``[Conv2d(bias=True), PReLU(out_planes)]``.
    """
    convolution = nn.Conv2d(
        in_planes,
        out_planes,
        kernel_size=kernel_size,
        stride=stride,
        padding=padding,
        dilation=dilation,
        bias=True,
    )
    activation = nn.PReLU(out_planes)
    return nn.Sequential(convolution, activation)
|
|
|
|
|
def conv_bn(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
    """Build a bias-free 2-D convolution, batch norm and per-channel PReLU.

    The convolution carries no bias because the following BatchNorm2d has
    its own affine shift.

    Args:
        in_planes: number of input channels.
        out_planes: number of output channels.
        kernel_size: convolution kernel size.
        stride: convolution stride.
        padding: zero padding added by the convolution.
        dilation: convolution dilation.

    Returns:
        nn.Sequential: ``[Conv2d(bias=False), BatchNorm2d, PReLU]``.
    """
    stages = [
        nn.Conv2d(
            in_planes,
            out_planes,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            bias=False,
        ),
        nn.BatchNorm2d(out_planes),
        nn.PReLU(out_planes),
    ]
    return nn.Sequential(*stages)
|
|
|
|
|
class DegCNN(nn.Module):
    """Small multi-scale CNN mapping a 3-channel image to a 3-channel image.

    Four stride-2 ``conv`` stages produce features at 1/2, 1/4, 1/8 and 1/16
    of the input resolution. The three coarser maps are bilinearly upsampled
    back to the 1/2 resolution, concatenated with the first (4 * 32 channels)
    and decoded by dropout, a stride-2 transposed convolution, PReLU, a 3x3
    convolution and a sigmoid.
    NOTE(review): the name suggests a degradation model — confirm with the
    training code.
    """

    def __init__(self):
        super().__init__()
        self.conv0 = conv(3, 32, kernel_size=3, stride=2, padding=1)
        self.conv1 = conv(32, 32, kernel_size=3, stride=2, padding=1)
        self.conv2 = conv(32, 32, kernel_size=3, stride=2, padding=1)
        self.conv3 = conv(32, 32, kernel_size=3, stride=2, padding=1)
        self.deconv = nn.Sequential(
            nn.Dropout2d(0.95),
            nn.ConvTranspose2d(4 * 32, 32, 4, 2, 1),
            nn.PReLU(32),
            nn.Conv2d(32, 3, 3, 1, 1),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Run the encoder pyramid and decode the fused features.

        Args:
            x: input tensor with 3 channels (required by ``conv0``).

        Returns:
            Tensor with 3 channels and values in (0, 1) from the sigmoid.
        """
        level0 = self.conv0(x)
        level1 = self.conv1(level0)
        level2 = self.conv2(level1)
        level3 = self.conv3(level2)

        # Bring every coarser level back to the resolution of level0.
        pyramid = [level0]
        for feature, factor in ((level1, 2.0), (level2, 4.0), (level3, 8.0)):
            pyramid.append(
                F.interpolate(
                    feature, scale_factor=factor, mode="bilinear", align_corners=False
                )
            )
        return self.deconv(torch.cat(pyramid, 1))
|
|
|
|
|
class FlowBlock(nn.Module):
    """One stage of the flow cascade used by ``ResynNet``.

    The input is encoded by three stride-2 ``conv_bn`` layers (1/8
    resolution), refined by a residual stack of six ``conv_bn`` layers and
    decoded by a stride-2 transposed convolution with 4 output channels.

    Args:
        in_planes: channel count of the tensor fed to the encoder, i.e. the
            image channels plus 2 when a previous flow is concatenated.
        c: base width; the trunk runs at ``2 * c`` channels.
    """

    def __init__(self, in_planes, c=64):
        super().__init__()
        trunk = 2 * c
        self.conv0 = nn.Sequential(
            conv_bn(in_planes, c // 2, 3, 2, 1),
            conv_bn(c // 2, c, 3, 2, 1),
            conv_bn(c, trunk, 3, 2, 1),
        )
        self.convblock = nn.Sequential(*(conv_bn(trunk, trunk) for _ in range(6)))
        self.lastconv = nn.ConvTranspose2d(trunk, 4, 4, 2, 1)

    def forward(self, x, flow, scale=1):
        """Predict a flow and a mask at the resolution of ``x``.

        Args:
            x: input tensor; it is resized by ``1 / scale`` before encoding.
            flow: previous flow estimate or ``None``. When given it is resized
                and divided by ``scale`` and concatenated to ``x``.
            scale: working-resolution divisor for this stage.

        Returns:
            tuple: ``(flow, mask)`` — the first two output channels multiplied
            by ``scale * 4`` and the third output channel. The fourth channel
            of the decoder output is not used.
        """
        x = F.interpolate(
            x, scale_factor=1.0 / scale, mode="bilinear", align_corners=False
        )
        if flow is not None:
            resized_flow = F.interpolate(
                flow, scale_factor=1.0 / scale, mode="bilinear", align_corners=False
            )
            # Flow magnitudes shrink together with the resolution.
            flow = resized_flow * 1.0 / scale
            x = torch.cat((x, flow), 1)

        encoded = self.conv0(x)
        encoded = self.convblock(encoded) + encoded
        decoded = self.lastconv(encoded)
        # The decoder output is at 1/4 of the working resolution; go back to
        # the caller's resolution in one resize.
        decoded = F.interpolate(
            decoded, scale_factor=scale * 4, mode="bilinear", align_corners=False
        )
        return decoded[:, :2] * scale * 4, decoded[:, 2:3]
|
|
|
|
|
class ResynNet(nn.Module):
    """Refinement network that re-synthesises an image from source frames.

    For every 3-channel frame packed into the input, a three-stage
    ``FlowBlock`` cascade estimates a flow and a blending logit relative to a
    reference image. The frame is warped with that flow and corrected by a
    small context encoder/decoder. All candidates (optionally including the
    reference itself) are merged with a softmax over the logits.
    """

    def __init__(self):
        super().__init__()
        # Stage 0 sees frame + reference (6 channels); later stages also see
        # the warped frame, the mask and the previous flow (3+3+3+1+2 = 12).
        self.block0 = FlowBlock(6, c=128)
        self.block1 = FlowBlock(12, c=128)
        self.block2 = FlowBlock(12, c=128)
        self.degrad = DegCNN()

        self.context0 = nn.Sequential(
            conv(3, 16, 3, 2, 1),
            conv(16, 32, 3, 2, 1),
        )
        self.context1 = nn.Sequential(
            conv(3, 16, 3, 2, 1),
            conv(16, 32, 3, 2, 1),
        )
        self.decode = nn.Sequential(
            nn.ConvTranspose2d(64, 32, 4, 2, 1),
            nn.ConvTranspose2d(32, 3, 4, 2, 1),
            nn.Tanh(),
        )

    def calflow(self, img0, lowres, scale):
        """Estimate the flow from ``img0`` towards ``lowres`` and warp it.

        Args:
            img0: 3-channel source frame.
            lowres: 3-channel reference image.
            scale: sequence indexed ``0..2`` with the scale of each stage.

        Returns:
            tuple: ``(flow, mask, refined)`` where ``refined`` is the warped
            frame plus the decoded context residual, clamped to ``[0, 1]``.
        """
        flow = None
        stu = [self.block0, self.block1, self.block2]
        for i in range(3):
            if flow is not None:
                flow_d, mask_d = stu[i](
                    torch.cat((img0, lowres, warped_img0, mask), 1),
                    flow,
                    scale=scale[i],
                )
                flow = flow + flow_d
                mask = mask + mask_d
            else:
                flow, mask = stu[i](torch.cat((img0, lowres), 1), None, scale=scale[i])
            warped_img0 = warp(img0, flow)

        # The context features live at 1/4 resolution, so the flow is resized
        # and rescaled to match before warping them.
        flow_down = (
            F.interpolate(flow, scale_factor=0.25, mode="bilinear", align_corners=False)
            * 0.25
        )
        c0 = warp(self.context0(img0), flow_down)
        c1 = self.context1(warped_img0)
        warped_img0 = warped_img0 + self.decode(torch.cat((c0, c1), 1))
        return flow, mask, torch.clamp(warped_img0, 0, 1)

    def forward(
        self, x, deg=None, gt=None, scale=(4, 2, 1), training=False, blend=True
    ):
        """Merge the frames packed in ``x`` into one refined image.

        Args:
            x: tensor whose channel dimension packs ``N`` 3-channel frames.
            deg: reference image; required unless ``training`` is true.
            gt: ground-truth image, only read when ``training`` is true, in
                which case ``deg`` is replaced by ``self.degrad(gt)``.
            scale: three per-stage scales forwarded to ``calflow``.
            training: enable the ``degrad`` branch and its consistency loss.
            blend: also include ``deg`` itself (with a zero logit) in the
                softmax blend.

        Returns:
            tuple: ``(merged, loss_cons)``; ``loss_cons`` is the mean absolute
            difference between ``gt`` and ``deg`` when training, otherwise
            ``torch.tensor([0])``.

        Raises:
            ValueError: if ``deg`` is missing in inference mode or ``x`` holds
                fewer than 3 channels.
        """
        if training:
            deg = self.degrad(gt)
            loss_cons = (gt - deg).abs().mean()
        else:
            if deg is None:
                raise ValueError(
                    "ResynNet.forward requires `deg` when training is False"
                )
            loss_cons = torch.tensor([0])

        num_frames = x.shape[1] // 3
        if num_frames == 0:
            raise ValueError("ResynNet.forward expects at least one 3-channel frame")

        reference = deg.detach()
        warped_list = []
        mask_list = []
        for i in range(num_frames):
            frame = x[:, i * 3 : i * 3 + 3]
            _, m, img = self.calflow(frame, reference, scale)
            mask_list.append(m)
            warped_list.append(img)

        if blend:
            # The reference competes with a logit of zero.
            mask_list.append(m * 0)
            warped_list.append(deg)

        mask = F.softmax(torch.clamp(torch.cat(mask_list, 1), -4, 4), dim=1)
        merged = 0
        for i, candidate in enumerate(warped_list):
            merged = merged + candidate * mask[:, i : i + 1]
        return merged, loss_cons
|
|
|
|
|
def make_layer(basic_block, num_basic_block, **kwarg):
    """Stack several freshly constructed copies of one block type.

    Args:
        basic_block (nn.module): nn.module class for basic block.
        num_basic_block (int): number of blocks.
        **kwarg: keyword arguments passed to every ``basic_block`` call.

    Returns:
        nn.Sequential: Stacked blocks in nn.Sequential.
    """
    return nn.Sequential(*(basic_block(**kwarg) for _ in range(num_basic_block)))
|
|
|
|
|
class ResidualDenseBlock(nn.Module):
    """Residual Dense Block.

    Used in RRDB block in ESRGAN.

    Five 3x3 convolutions are densely connected: each one receives the block
    input concatenated with the outputs of all earlier convolutions. The
    last output is scaled by 0.2 and added to the block input.

    Args:
        num_feat (int): Channel number of intermediate features.
        num_grow_ch (int): Channels for each growth.
    """

    def __init__(self, num_feat=64, num_grow_ch=32):
        super().__init__()
        grow = num_grow_ch
        self.conv1 = nn.Conv2d(num_feat, grow, 3, 1, 1)
        self.conv2 = nn.Conv2d(num_feat + grow, grow, 3, 1, 1)
        self.conv3 = nn.Conv2d(num_feat + 2 * grow, grow, 3, 1, 1)
        self.conv4 = nn.Conv2d(num_feat + 3 * grow, grow, 3, 1, 1)
        self.conv5 = nn.Conv2d(num_feat + 4 * grow, num_feat, 3, 1, 1)

        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

    def forward(self, x):
        """Return ``x`` plus 0.2 times the densely connected residual."""
        dense = [x]
        for layer in (self.conv1, self.conv2, self.conv3, self.conv4):
            dense.append(self.lrelu(layer(torch.cat(dense, 1))))
        residual = self.conv5(torch.cat(dense, 1))
        return residual * 0.2 + x
|
|
|
|
|
class RRDB(nn.Module):
    """Residual in Residual Dense Block.

    Used in RRDB-Net in ESRGAN.

    Three ``ResidualDenseBlock`` modules run in sequence; their result is
    scaled by 0.2 and added to the block input.

    Args:
        num_feat (int): Channel number of intermediate features.
        num_grow_ch (int): Channels for each growth.
    """

    def __init__(self, num_feat, num_grow_ch=32):
        super().__init__()
        self.rdb1 = ResidualDenseBlock(num_feat, num_grow_ch)
        self.rdb2 = ResidualDenseBlock(num_feat, num_grow_ch)
        self.rdb3 = ResidualDenseBlock(num_feat, num_grow_ch)

    def forward(self, x):
        """Return ``x`` plus 0.2 times the output of the three dense blocks."""
        out = x
        for dense_block in (self.rdb1, self.rdb2, self.rdb3):
            out = dense_block(out)
        return out * 0.2 + x
|
|
|
|
|
class RRDBNet(nn.Module):
    """RRDB trunk (as used in ESRGAN) that predicts a sigmoid map.

    ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks.

    In this file the network works on inputs bilinearly reduced to 1/4
    resolution and upsamples its features twice by a factor of 2 (nearest
    neighbour), so the output has the spatial size of the inputs.

    Args:
        num_in_ch (int): Channel number of inputs. Default: 16, which matches
            four 3-channel images plus a 4-channel flow.
        num_out_ch (int): Channel number of outputs. Default: 1.
        num_feat (int): Channel number of intermediate features.
            Default: 64
        num_block (int): Block number in the trunk network. Default: 6.
        num_grow_ch (int): Channels for each growth. Default: 32.
    """

    def __init__(
        self, num_in_ch=16, num_out_ch=1, num_feat=64, num_block=6, num_grow_ch=32
    ):
        super().__init__()
        self.conv_first = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)
        self.body = make_layer(
            RRDB, num_block, num_feat=num_feat, num_grow_ch=num_grow_ch
        )
        self.conv_body = nn.Conv2d(num_feat, num_feat, 3, 1, 1)

        self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
        self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)

        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

    def forward(self, img0, img1, warped_img0, warped_img1, flow):
        """Predict a map in (0, 1) from two frames, their warps and the flow.

        Args:
            img0, img1: the two input frames.
            warped_img0, warped_img1: the frames after warping.
            flow: flow tensor; it is resized to 1/4 and multiplied by 0.25.

        Returns:
            Sigmoid-activated tensor with ``num_out_ch`` channels.
        """

        def quarter(tensor):
            return F.interpolate(
                tensor, scale_factor=0.25, mode="bilinear", align_corners=False
            )

        def double(tensor):
            return F.interpolate(tensor, scale_factor=2.0, mode="nearest")

        frames = quarter(torch.cat((img0, img1, warped_img0, warped_img1), 1))
        small_flow = quarter(flow) * 0.25

        shallow = self.conv_first(torch.cat((frames, small_flow), 1))
        deep = shallow + self.conv_body(self.body(shallow))

        deep = self.lrelu(self.conv_up1(double(deep)))
        deep = self.lrelu(self.conv_up2(double(deep)))
        logits = self.conv_last(self.lrelu(self.conv_hr(deep)))
        return torch.sigmoid(logits)
|
|
|
|
|
def warp(tenInput, tenFlow):
    """Backward-warp ``tenInput`` with a per-pixel flow using ``grid_sample``.

    A base grid spanning ``[-1, 1]`` in both axes is built once per
    ``(device, flow size)`` and cached in ``backwarp_tenGrid``. The flow is
    converted from pixels to that normalised range and added to the grid.

    The grid is allocated on the device of ``tenFlow`` (the same device that
    is part of the cache key), so the function also works when the tensors do
    not live on the module-level default device.

    Args:
        tenInput: tensor of shape ``(N, C, H, W)`` to sample from.
        tenFlow: tensor of shape ``(N, 2, H, W)``; channel 0 is the horizontal
            and channel 1 the vertical displacement in pixels.

    Returns:
        The bilinearly sampled tensor; samples outside the image use border
        padding.
    """
    k = (str(tenFlow.device), str(tenFlow.size()))
    if k not in backwarp_tenGrid:
        batch, _, height, width = tenFlow.shape
        flow_device = tenFlow.device
        tenHorizontal = (
            torch.linspace(-1.0, 1.0, width, device=flow_device)
            .view(1, 1, 1, width)
            .expand(batch, -1, height, -1)
        )
        tenVertical = (
            torch.linspace(-1.0, 1.0, height, device=flow_device)
            .view(1, 1, height, 1)
            .expand(batch, -1, -1, width)
        )
        backwarp_tenGrid[k] = torch.cat([tenHorizontal, tenVertical], 1)

    # With align_corners=True a shift of one pixel equals 2 / (size - 1) in
    # normalised coordinates.
    tenFlow = torch.cat(
        [
            tenFlow[:, 0:1, :, :] / ((tenInput.shape[3] - 1.0) / 2.0),
            tenFlow[:, 1:2, :, :] / ((tenInput.shape[2] - 1.0) / 2.0),
        ],
        1,
    )

    g = (backwarp_tenGrid[k] + tenFlow).permute(0, 2, 3, 1)
    return torch.nn.functional.grid_sample(
        input=tenInput,
        grid=g,
        mode="bilinear",
        padding_mode="border",
        align_corners=True,
    )
|
|
|
|
|
class BasicConv(nn.Module):
    """Conv2d with optional batch normalisation and optional ReLU.

    Args:
        in_planes: number of input channels.
        out_planes: number of output channels (kept as ``out_channels``).
        kernel_size: convolution kernel size.
        stride, padding, dilation, groups: forwarded to ``nn.Conv2d``.
        relu: append a ReLU when true.
        bn: append ``BatchNorm2d(eps=1e-5, momentum=0.01, affine=True)``
            when true.
        bias: whether the convolution has a bias term.
    """

    def __init__(
        self,
        in_planes,
        out_planes,
        kernel_size,
        stride=1,
        padding=0,
        dilation=1,
        groups=1,
        relu=True,
        bn=True,
        bias=False,
    ):
        super().__init__()
        self.out_channels = out_planes
        self.conv = nn.Conv2d(
            in_planes,
            out_planes,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
        )
        if bn:
            self.bn = nn.BatchNorm2d(out_planes, eps=1e-5, momentum=0.01, affine=True)
        else:
            self.bn = None
        if relu:
            self.relu = nn.ReLU()
        else:
            self.relu = None

    def forward(self, x):
        """Apply the convolution, then whichever optional stages are enabled."""
        out = self.conv(x)
        for stage in (self.bn, self.relu):
            if stage is not None:
                out = stage(out)
        return out
|
|
|
|
|
class Flatten(nn.Module):
    """Collapse every dimension after the first into a single one."""

    def forward(self, x):
        """Return ``x`` viewed as ``(x.size(0), -1)``."""
        batch = x.size(0)
        return x.view(batch, -1)
|
|
|
|
|
class ChannelGate(nn.Module):
    """Channel attention: rescale each channel by a gate in (0, 1).

    Every configured pooling reduces the spatial dimensions to one value per
    channel, the pooled vector goes through a shared two-layer MLP, the MLP
    outputs of all poolings are summed and a sigmoid of that sum multiplies
    the input.

    Args:
        gate_channels: number of channels of the input.
        reduction_ratio: hidden width of the MLP is
            ``gate_channels // reduction_ratio``.
        pool_types: iterable of pooling names out of ``"avg"``, ``"max"``,
            ``"lp"`` and ``"lse"``. Stored as a tuple so instances never share
            a mutable default.
    """

    def __init__(self, gate_channels, reduction_ratio=16, pool_types=("avg", "max")):
        super().__init__()
        self.gate_channels = gate_channels
        self.mlp = nn.Sequential(
            Flatten(),
            nn.Linear(gate_channels, gate_channels // reduction_ratio),
            nn.ReLU(),
            nn.Linear(gate_channels // reduction_ratio, gate_channels),
        )
        self.pool_types = tuple(pool_types)

    def _pool(self, x, pool_type):
        """Reduce ``x`` over its spatial dimensions with the named pooling."""
        window = (x.size(2), x.size(3))
        if pool_type == "avg":
            return F.avg_pool2d(x, window, stride=window)
        if pool_type == "max":
            return F.max_pool2d(x, window, stride=window)
        if pool_type == "lp":
            return F.lp_pool2d(x, 2, window, stride=window)
        if pool_type == "lse":
            return logsumexp_2d(x)
        raise ValueError("unknown pool type: {!r}".format(pool_type))

    def forward(self, x):
        """Return ``x`` multiplied by its per-channel attention gate.

        Raises:
            ValueError: if a pool type is unknown or none is configured.
        """
        channel_att_sum = None
        for pool_type in self.pool_types:
            channel_att_raw = self.mlp(self._pool(x, pool_type))
            if channel_att_sum is None:
                channel_att_sum = channel_att_raw
            else:
                channel_att_sum = channel_att_sum + channel_att_raw

        if channel_att_sum is None:
            raise ValueError("ChannelGate needs at least one pool type")

        scale = torch.sigmoid(channel_att_sum).unsqueeze(2).unsqueeze(3).expand_as(x)
        return x * scale
|
|
|
|
|
def logsumexp_2d(tensor):
    """Compute log-sum-exp over everything after the first two dimensions.

    The per-channel maximum is subtracted before exponentiation so large
    values do not overflow.

    Args:
        tensor: tensor with at least two dimensions.

    Returns:
        Tensor of shape ``(tensor.size(0), tensor.size(1), 1)``.
    """
    batch, channels = tensor.size(0), tensor.size(1)
    flat = tensor.view(batch, channels, -1)
    peak = flat.max(dim=2, keepdim=True)[0]
    shifted_sum = torch.exp(flat - peak).sum(dim=2, keepdim=True)
    return peak + torch.log(shifted_sum)
|
|
|
|
|
class ChannelPool(nn.Module):
    """Reduce the channel dimension to two maps: its maximum and its mean."""

    def forward(self, x):
        """Return a 2-channel tensor ``[max over dim 1, mean over dim 1]``."""
        peak = x.max(dim=1, keepdim=True)[0]
        average = x.mean(dim=1, keepdim=True)
        return torch.cat((peak, average), dim=1)
|
|
|
|
|
class SpatialGate(nn.Module):
    """Spatial attention: rescale each location by a gate in (0, 1).

    The input is compressed to two maps by ``ChannelPool``; a 7x7
    ``BasicConv`` (batch norm, no ReLU) turns them into one map whose sigmoid
    multiplies the input.
    """

    def __init__(self):
        super().__init__()
        kernel_size = 7
        self.compress = ChannelPool()
        self.spatial = BasicConv(
            2, 1, kernel_size, stride=1, padding=(kernel_size - 1) // 2, relu=False
        )

    def forward(self, x):
        """Return ``x`` multiplied by its single-channel spatial gate."""
        x_compress = self.compress(x)
        x_out = self.spatial(x_compress)
        # torch.sigmoid replaces the deprecated F.sigmoid alias and matches
        # the call already used by RRDBNet in this file.
        scale = torch.sigmoid(x_out)
        return x * scale
|
|
|
|
|
class CBAM(nn.Module):
    """Channel attention followed by optional spatial attention.

    Args:
        gate_channels: number of channels of the input.
        reduction_ratio: forwarded to ``ChannelGate``.
        pool_types: pooling names forwarded to ``ChannelGate``. The default
            is an immutable tuple so it cannot be shared and mutated across
            instances.
        no_spatial: skip the ``SpatialGate`` stage when true.
    """

    def __init__(
        self,
        gate_channels,
        reduction_ratio=16,
        pool_types=("avg", "max"),
        no_spatial=False,
    ):
        super().__init__()
        self.ChannelGate = ChannelGate(gate_channels, reduction_ratio, pool_types)
        self.no_spatial = no_spatial
        if not no_spatial:
            self.SpatialGate = SpatialGate()

    def forward(self, x):
        """Apply the channel gate and, unless disabled, the spatial gate."""
        x_out = self.ChannelGate(x)
        if not self.no_spatial:
            x_out = self.SpatialGate(x_out)
        return x_out
|
|
|
|
|
def conv(in_planes, out_planes, kernel_size=3, stride=1, padding=1, dilation=1):
    """Return ``Conv2d`` (with bias) followed by a per-channel ``PReLU``.

    NOTE(review): a function with the same name, signature and body is
    defined earlier in this file; this later definition rebinds the name.

    Args:
        in_planes: number of input channels.
        out_planes: number of output channels.
        kernel_size: convolution kernel size.
        stride: convolution stride.
        padding: zero padding added by the convolution.
        dilation: convolution dilation.

    Returns:
        nn.Sequential: the two layers in order.
    """
    layers = [
        nn.Conv2d(
            in_planes,
            out_planes,
            kernel_size,
            stride,
            padding,
            dilation,
            bias=True,
        ),
        nn.PReLU(out_planes),
    ]
    return nn.Sequential(*layers)
|
|
|
|
|
class UNetConv(nn.Module):
    """Encoder stage: a stride-2 ``conv``, a stride-1 ``conv``, optional CBAM.

    Args:
        in_planes: number of input channels.
        out_planes: number of output channels.
        att: append a ``CBAM(out_planes, 16)`` attention module when true.
    """

    def __init__(self, in_planes, out_planes, att=True):
        super().__init__()
        self.conv1 = conv(in_planes, out_planes, 3, 2, 1)
        self.conv2 = conv(out_planes, out_planes, 3, 1, 1)
        self.cbam = CBAM(out_planes, 16) if att else None

    def forward(self, x):
        """Halve the resolution of ``x`` and return the (attended) features."""
        features = self.conv2(self.conv1(x))
        if self.cbam is None:
            return features
        return self.cbam(features)
|
|
|
|
|
class UpConv(nn.Module):
    """Decoder stage: upsample, fuse with a skip connection, optional CBAM.

    ``x1`` is upsampled by a stride-2 transposed convolution to
    ``in_planes // 2`` channels, concatenated with the skip tensor ``x2`` and
    reduced by two ``conv`` layers to ``out_planes`` channels.

    Args:
        in_planes: channel count of the coarse input and of the concatenated
            tensor, so the skip tensor must carry ``in_planes // 2`` channels.
        out_planes: number of output channels.
        att: append a ``CBAM(out_planes, 16)`` attention module when true.
    """

    def __init__(self, in_planes, out_planes, att=True):
        super().__init__()
        half = in_planes // 2
        self.deconv = nn.Sequential(
            nn.ConvTranspose2d(in_planes, half, 4, 2, 1),
            nn.PReLU(half),
        )

        self.conv1 = conv(in_planes, half, 3, 1, 1)
        self.conv2 = conv(half, out_planes, 3, 1, 1)

        self.cbam = CBAM(out_planes, 16) if att else None

    def forward(self, x1, x2):
        """Fuse the upsampled ``x1`` with the skip tensor ``x2``."""
        upsampled = self.deconv(x1)
        fused = self.conv2(self.conv1(torch.cat((upsampled, x2), 1)))
        if self.cbam is None:
            return fused
        return self.cbam(fused)
|
|
|
|
|
class FeatureNet(nn.Module):
    """U-Net style feature extractor with a selectable output resolution.

    Five ``UNetConv`` stages halve the resolution each (1/2 ... 1/32) and up
    to three ``UpConv`` stages decode back to 1/16, 1/8 or 1/4.

    Args:
        in_planes: channel count expected by the first encoder stage. Inputs
            that already have this many channels skip ``conv0``.
        out_planes: channel count of the 1/16-resolution output; finer outputs
            carry ``out_planes // 2`` and ``out_planes // 4`` channels.
    """

    def __init__(self, in_planes, out_planes):
        super().__init__()
        # Remembered so forward() can tell whether the input still has to be
        # lifted by conv0 (previously compared against a hard-coded 17).
        self.in_planes = in_planes

        # 1x1 lift for 7-channel inputs.
        self.conv0 = conv(7, in_planes, 1, 1, 0)

        self.conv1 = UNetConv(in_planes, out_planes // 8, att=False)
        self.conv2 = UNetConv(out_planes // 8, out_planes // 4, att=True)
        self.conv3 = UNetConv(out_planes // 4, out_planes // 2, att=True)
        self.conv4 = UNetConv(out_planes // 2, out_planes, att=True)
        self.conv5 = UNetConv(out_planes, 2 * out_planes, att=True)

        self.deconv5 = UpConv(2 * out_planes, out_planes, att=True)
        self.deconv4 = UpConv(out_planes, out_planes // 2, att=False)
        self.deconv3 = UpConv(out_planes // 2, out_planes // 4, att=False)

    def forward(self, x, level=0):
        """Extract features at the resolution selected by ``level``.

        Args:
            x: input tensor with either ``in_planes`` channels or the 7
                channels accepted by ``conv0``.
            level: 0 returns the 1/16-resolution features, any other value
                also applies ``deconv4`` (1/8) and 2 additionally applies
                ``deconv3`` (1/4).

        Returns:
            The decoded feature map.
        """
        if x.shape[1] != self.in_planes:
            x = self.conv0(x)
        x2 = self.conv1(x)
        x4 = self.conv2(x2)
        x8 = self.conv3(x4)
        x16 = self.conv4(x8)
        x32 = self.conv5(x16)
        y = self.deconv5(x32, x16)

        if level != 0:
            y = self.deconv4(y, x8)
        if level == 2:
            y = self.deconv3(y, x4)
        return y
|
|
|
|
|
class IFBlock(nn.Module):
    """Flow head: residual conv stack plus learned convex upsampling.

    Args:
        c: channel count of the incoming feature map.
        level: upsampling factor of the head; must be 4, 8 or 16.

    Raises:
        ValueError: if ``level`` is not one of the supported factors. The
            default of 0 is therefore never valid and a level must be passed.
    """

    _SUPPORTED_LEVELS = (4, 8, 16)

    def __init__(self, c=64, level=0):
        super().__init__()
        # Validate first, with a real exception: an assert would vanish
        # under ``python -O`` and let mask_conv() silently return None.
        if level not in self._SUPPORTED_LEVELS:
            raise ValueError(
                "IFBlock level must be one of {}, got {!r}".format(
                    self._SUPPORTED_LEVELS, level
                )
            )
        self.convblock = nn.Sequential(
            conv(c, c),
            conv(c, c),
            conv(c, c),
            conv(c, c),
            conv(c, c),
            conv(c, c),
        )
        self.flowconv = nn.Conv2d(c, 4, 3, 1, 1)
        # All three mask heads are always created; only the one matching
        # ``level`` is used by mask_conv().
        self.maskconvx16 = nn.Conv2d(c, 16 * 16 * 9, 1, 1, 0)
        self.maskconvx8 = nn.Conv2d(c, 8 * 8 * 9, 1, 1, 0)
        self.maskconvx4 = nn.Conv2d(c, 4 * 4 * 9, 1, 1, 0)

        self.level = level

    def mask_conv(self, x):
        """Apply the mask head that belongs to ``self.level``."""
        if self.level == 4:
            return self.maskconvx4(x)
        if self.level == 8:
            return self.maskconvx8(x)
        if self.level == 16:
            return self.maskconvx16(x)

    def upsample_flow(self, flow, mask):
        """Upsample a 4-channel flow by ``level`` with per-pixel 3x3 weights.

        Args:
            flow: tensor of shape ``(N, 4, H, W)``.
            mask: tensor of shape ``(N, 9 * level * level, H, W)`` holding the
                unnormalised weights; a softmax over the 9 neighbours is
                applied here.

        Returns:
            Tensor of shape ``(N, 4, level * H, level * W)``; flow values are
            multiplied by ``level`` to match the new resolution.
        """
        N, _, H, W = flow.shape
        mask = mask.view(N, 1, 9, self.level, self.level, H, W)
        mask = torch.softmax(mask, dim=2)

        # Gather the zero-padded 3x3 neighbourhood of every coarse pixel.
        up_flow = F.unfold(self.level * flow, [3, 3], padding=1)
        up_flow = up_flow.view(N, 4, 9, 1, 1, H, W)

        up_flow = torch.sum(mask * up_flow, dim=2)
        # (N, 4, level, level, H, W) -> (N, 4, H, level, W, level)
        up_flow = up_flow.permute(0, 1, 4, 2, 5, 3)
        return up_flow.reshape(N, 4, self.level * H, self.level * W)

    def forward(self, x, scale):
        """Predict a flow from features ``x`` and resize it by ``scale``.

        Args:
            x: feature map with ``c`` channels.
            scale: factor for the final bilinear resize; flow values are
                multiplied by it as well.

        Returns:
            The 4-channel flow.
        """
        x = self.convblock(x) + x
        tmp = self.flowconv(x)
        up_mask = self.mask_conv(x)
        flow_up = self.upsample_flow(tmp, up_mask)
        flow = (
            F.interpolate(
                flow_up, scale_factor=scale, mode="bilinear", align_corners=False
            )
            * scale
        )
        return flow
|
|
|
|
|
class IFUNet(nn.Module):
    """Coarse-to-fine flow estimator built on a shared ``FeatureNet``.

    Three ``IFBlock`` heads work on features at 1/16, 1/8 and 1/4 resolution.
    The first stage predicts a 4-channel flow from the two frames and the
    timestep map; each later stage adds a correction computed from the
    frames, their current warps and the current flow.
    """

    def __init__(self):
        super().__init__()

        self.fmap = FeatureNet(in_planes=17, out_planes=256)
        self.block0 = IFBlock(c=256, level=16)
        self.block1 = IFBlock(c=128, level=8)
        self.block2 = IFBlock(c=64, level=4)

    def _estimate(self, stage, frames, flow, scale):
        """Run one feature extraction plus flow head pass.

        Args:
            stage: index (0..2) selecting both the flow head and the
                ``FeatureNet`` level.
            frames: channel-wise concatenation of the stage inputs.
            flow: current flow to append to the input, or ``None``.
            scale: working-resolution factor; inputs are resized by it and
                the head resizes its output by ``1 / scale``.

        Returns:
            The flow predicted by the selected head.
        """
        x = frames
        if scale != 1:
            x = F.interpolate(
                x, scale_factor=scale, mode="bilinear", align_corners=False
            )
        if flow is not None:
            flowtmp = flow
            if scale != 1:
                flowtmp = (
                    F.interpolate(
                        flow,
                        scale_factor=scale,
                        mode="bilinear",
                        align_corners=False,
                    )
                    * scale
                )
            x = torch.cat((x, flowtmp), 1)

        block = (self.block0, self.block1, self.block2)[stage]
        return block(self.fmap(x, level=stage), scale=1.0 / scale)

    def forward(self, x, scale=1.0, timestep=0.5, ensemble=True):
        """Estimate the flow between two frames and warp both of them.

        Args:
            x: the two frames concatenated along the channel dimension.
            scale: working-resolution factor for all stages.
            timestep: scalar, or a tensor that is repeated over the spatial
                dimensions (presumably shaped ``(N, 1, 1, 1)`` — confirm
                with callers).
            ensemble: additionally run every stage on the swapped frames with
                ``1 - timestep`` and average the two estimates.

        Returns:
            tuple: ``(flow, warped_img0, warped_img1)``; ``flow[:, :2]`` warps
            the first frame and ``flow[:, 2:4]`` the second.
        """
        channel = x.shape[1] // 2
        img0 = x[:, :channel]
        img1 = x[:, channel:]
        if not torch.is_tensor(timestep):
            # Constant map shaped like one channel of the input.
            timestep = (x[:, :1].clone() * 0 + 1) * timestep
        else:
            timestep = timestep.repeat(1, 1, img0.shape[2], img0.shape[3])

        warped_img0 = img0
        warped_img1 = img1
        flow = None
        for i in range(3):
            if flow is not None:
                frames = torch.cat(
                    (img0, img1, timestep, warped_img0, warped_img1), 1
                )
                flow = flow + self._estimate(i, frames, flow, scale)

                if ensemble:
                    frames = torch.cat(
                        (img1, img0, 1 - timestep, warped_img0, warped_img1), 1
                    )
                    flow2 = flow + self._estimate(i, frames, flow, scale)
                    flow = (flow + flow2) / 2
            else:
                frames = torch.cat((img0, img1, timestep), 1)
                flow = self._estimate(i, frames, None, scale)

                if ensemble:
                    frames = torch.cat((img1, img0, 1 - timestep), 1)
                    flow2 = self._estimate(i, frames, None, scale)
                    flow = (flow + flow2) / 2

            warped_img0 = warp(img0, flow[:, :2])
            warped_img1 = warp(img1, flow[:, 2:4])
        return flow, warped_img0, warped_img1
|
|
|
|
|
class IFUNetModel(nn.Module):
    """Full pipeline: flow estimation, mask-based fusion and refinement.

    Args:
        local_rank: accepted but not used by this class.
    """

    def __init__(self, local_rank=-1):
        super().__init__()
        self.flownet = IFUNet()
        self.fusionnet = RRDBNet()
        self.refinenet = ResynNet()

    def forward(self, img0, img1, timestep=0.5, scale=1.0, ensemble=False):
        """Produce one image from two input frames.

        Both frames are padded on the right and bottom to multiples of 64,
        processed, and the result is cropped back to the original size.

        Args:
            img0, img1: 4-D frame tensors of identical shape.
            timestep: forwarded to the flow network.
            scale: forwarded to the flow network.
            ensemble: forwarded to the flow network.

        Returns:
            The refined image cropped to the height and width of ``img0``.
        """
        _, _, height, width = img0.shape
        # Distance to the next multiple of 64 (zero when already aligned).
        pad_bottom = (-height) % 64
        pad_right = (-width) % 64
        padding = (0, pad_right, 0, pad_bottom)
        img0 = F.pad(img0, padding)
        img1 = F.pad(img1, padding)

        imgs = torch.cat((img0, img1), 1)
        flow, warped_img0, warped_img1 = self.flownet(
            imgs, scale=scale, timestep=timestep, ensemble=ensemble
        )
        mask = self.fusionnet(img0, img1, warped_img0, warped_img1, flow)
        fused = warped_img0 * mask + warped_img1 * (1 - mask)
        refined, _ = self.refinenet(imgs, deg=fused, scale=[4, 2, 1])
        return refined[:, :, :height, :width]
|
|