|
from torch import nn |
|
import torch.nn.functional as F |
|
import torch |
|
|
|
|
|
class AntiAliasInterpolation2d(nn.Module):
    """
    Band-limited downsampling, for better preservation of the input signal.

    The input is low-pass filtered with a per-channel (depthwise) Gaussian
    whose width grows as ``scale`` shrinks, then subsampled by keeping every
    ``round(1 / scale)``-th row and column.

    Args:
        channels: number of input channels. The same Gaussian kernel is
            replicated once per channel and applied with ``groups=channels``.
        scale: downsampling factor in ``(0, 1]``. ``1.0`` makes ``forward`` the
            identity. The output size matches ``scale`` exactly only when
            ``1 / scale`` is an integer; otherwise the stride is
            ``round(1 / scale)``.

    Raises:
        ValueError: if ``scale`` is not in ``(0, 1]``.
    """

    def __init__(self, channels, scale):
        super(AntiAliasInterpolation2d, self).__init__()
        if not 0 < scale <= 1:
            # Outside this range the kernel size / subsampling stride would be
            # zero or negative and fail later with an unrelated error.
            raise ValueError(
                "scale must be in (0, 1], got {!r}".format(scale))

        # Gaussian std grows with the downsampling factor; it is 0 at scale == 1.
        sigma = (1 / scale - 1) / 2
        # Kernel spans about +/- 4 std; 2 * r + 1 is always odd.
        kernel_size = 2 * round(sigma * 4) + 1
        # Left/top (ka) and right/bottom (kb) padding that keep the conv output
        # the same spatial size as the input. For the odd sizes produced
        # above, kb == ka; the even branch is kept for safety.
        self.ka = kernel_size // 2
        self.kb = self.ka - 1 if kernel_size % 2 == 0 else self.ka

        if sigma > 0:
            # The 2-D Gaussian is separable: outer product of a 1-D Gaussian
            # with itself (same values as the meshgrid product formulation,
            # without relying on meshgrid's indexing default).
            coords = torch.arange(kernel_size, dtype=torch.float32)
            mean = (kernel_size - 1) / 2
            gauss_1d = torch.exp(-(coords - mean) ** 2 / (2 * sigma ** 2))
            kernel = gauss_1d[:, None] * gauss_1d[None, :]
            # Normalize so the filter preserves the mean intensity.
            kernel = kernel / torch.sum(kernel)
        else:
            # scale == 1: sigma is 0 and the Gaussian formula would be 0/0
            # (NaN). Use an identity 1x1 kernel so the buffer stays finite.
            kernel = torch.ones(kernel_size, kernel_size, dtype=torch.float32)

        # Shape (channels, 1, k, k): one kernel per channel for a depthwise conv.
        kernel = kernel.view(1, 1, *kernel.size())
        kernel = kernel.repeat(channels, *[1] * (kernel.dim() - 1))

        self.register_buffer('weight', kernel)
        self.groups = channels
        self.scale = scale

        # Subsampling stride. round() guards against 1 / scale landing a hair
        # below an integer due to floating-point error (int() would truncate).
        inv_scale = 1 / scale
        self.int_inv_scale = int(round(inv_scale))

    def forward(self, input):
        """Blur ``input`` channel-wise and subsample it by ``int_inv_scale``.

        ``input`` is expected to be a 4-D tensor with ``channels`` channels
        (required by the grouped ``conv2d``). It is returned unchanged when
        ``scale == 1.0``.
        """
        if self.scale == 1.0:
            return input

        # Zero-pad so the valid convolution keeps the spatial size unchanged.
        out = F.pad(input, (self.ka, self.kb, self.ka, self.kb))
        out = F.conv2d(out, weight=self.weight, groups=self.groups)
        # Keep every int_inv_scale-th row and column.
        out = out[:, :, ::self.int_inv_scale, ::self.int_inv_scale]

        return out
|
|