Spaces:
Sleeping
Sleeping
""" Median Pool | |
Hacked together by / Copyright 2020 Ross Wightman | |
""" | |
import torch
import torch.nn as nn
import torch.nn.functional as F

from .helpers import to_2tuple, to_4tuple
class MedianPool2d(nn.Module):
    """ Median pool (usable as median filter when stride=1) module.

    Each output element is the median of one ``kernel_size`` window of the
    reflect-padded input, computed independently for every batch item and
    channel (input is expected as (N, C, H, W)).

    Args:
        kernel_size: size of pooling kernel, int or 2-tuple (h, w)
        stride: pool stride, int or 2-tuple (h, w)
        padding: pool padding, int or 4-tuple (l, r, t, b) as in pytorch F.pad
        same: override padding and enforce TF-style 'SAME' padding, boolean

    Notes:
        * Padding uses ``F.pad(..., mode='reflect')``; PyTorch rejects reflect
          padding that is not smaller than the corresponding input dimension.
        * When a window holds an even number of elements, ``torch.median``
          returns the lower of the two middle values, not their mean.
    """

    def __init__(self, kernel_size=3, stride=1, padding=0, same=False):
        super().__init__()
        self.k = to_2tuple(kernel_size)  # (kh, kw)
        self.stride = to_2tuple(stride)  # (sh, sw)
        self.padding = to_4tuple(padding)  # convert to l, r, t, b
        self.same = same

    def extra_repr(self) -> str:
        # Shown inside repr(module) so the configuration is visible when debugging.
        return f'kernel_size={self.k}, stride={self.stride}, padding={self.padding}, same={self.same}'

    @staticmethod
    def _same_pad(size: int, k: int, s: int) -> int:
        """Return the total 'SAME' padding for one spatial axis of length ``size``.

        Chosen so the pooled length is ceil(size / s), as in TensorFlow 'SAME'.
        """
        rem = size % s
        if rem == 0:
            return max(k - s, 0)
        return max(k - rem, 0)

    def _padding(self, x: torch.Tensor) -> tuple:
        """Return the (l, r, t, b) padding to apply to ``x`` before pooling."""
        if not self.same:
            return self.padding
        ih, iw = x.size()[2:]
        ph = self._same_pad(ih, self.k[0], self.stride[0])
        pw = self._same_pad(iw, self.k[1], self.stride[1])
        # Split the total evenly; an odd leftover pixel goes to the right/bottom.
        pl = pw // 2
        pt = ph // 2
        return (pl, pw - pl, pt, ph - pt)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Median-pool ``x`` of shape (N, C, H, W); returns (N, C, H_out, W_out)."""
        padding = self._padding(x)
        if any(padding):
            # Skipping an all-zero pad avoids a needless copy; values are unchanged.
            x = F.pad(x, padding, mode='reflect')
        # (N, C, H, W) -> (N, C, H_out, W_out, kh, kw): one sliding window per output pixel.
        x = x.unfold(2, self.k[0], self.stride[0]).unfold(3, self.k[1], self.stride[1])
        # Flatten each kh*kw window and reduce it to its median.
        x = x.contiguous().view(x.size()[:4] + (-1,)).median(dim=-1).values
        return x