Spaces:
Sleeping
Sleeping
import math | |
import numpy as np | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from deepinv.models.unet import BFBatchNorm2d | |
from deepinv.physics.blur import gaussian_blur | |
from deepinv.physics.functional import conv2d | |
from deepinv.utils import TensorList | |
from timm.models.layers import trunc_normal_, DropPath | |
def normalize(x, dim=None, eps=1e-4):
    """Divide ``x`` by a rescaled L2 norm computed over ``dim``.

    :param torch.Tensor x: tensor to normalize.
    :param dim: dimension(s) reduced by the norm; when ``None`` every
        dimension except the first is reduced.
    :param float eps: constant added to the rescaled norm before dividing.
    :return: ``x`` divided by the rescaled norm, with the norm cast to the
        dtype of ``x``.
    """
    reduce_dims = list(range(1, x.ndim)) if dim is None else dim
    # The norm is always accumulated in float32, whatever the dtype of x.
    magnitude = torch.linalg.vector_norm(
        x, dim=reduce_dims, keepdim=True, dtype=torch.float32
    )
    # eps + sqrt(#norm entries / #input entries) * norm
    rescale = np.sqrt(magnitude.numel() / x.numel())
    magnitude = torch.add(eps, magnitude, alpha=rescale)
    return x / magnitude.to(x.dtype)
class TimestepEmbedding(nn.Module):
    """Embed timesteps with sinusoidal features followed by a small MLP.

    The MLP is ``Linear -> SiLU -> Linear`` and maps the
    ``frequency_embedding_size`` sinusoidal features to ``hidden_size``.

    :param int hidden_size: width of the MLP output.
    :param int frequency_embedding_size: number of sinusoidal features fed
        to the MLP.
    """

    def __init__(self, hidden_size, frequency_embedding_size=256):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, hidden_size),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size),
        )
        self.frequency_embedding_size = frequency_embedding_size

    # Declared static: the function takes no ``self`` and ``forward`` calls it
    # through the instance, which would otherwise shift every argument by one.
    @staticmethod
    def timestep_embedding(t, dim, max_period=10000):
        """Build cosine/sine features for the timesteps ``t``.

        :param torch.Tensor t: timesteps, indexed as ``t[:, None]``.
        :param int dim: size of the last dimension of the output.
        :param max_period: controls the smallest frequency used.
        :return: tensor whose last dimension has size ``dim``; the first half
            holds cosines, the second half sines, and one zero column is
            appended when ``dim`` is odd.
        """
        half = dim // 2
        freqs = torch.exp(
            -math.log(max_period) * torch.arange(start=0, end=half) / half
        ).to(t.device)
        args = t[:, None] * freqs[None]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if dim % 2:
            # Pad with a zero column so the output width is exactly ``dim``.
            embedding = torch.cat(
                [embedding, torch.zeros_like(embedding[:, :1])], dim=-1
            )
        return embedding

    def forward(self, t):
        """Return the MLP embedding of the timesteps ``t``."""
        # Match the MLP parameter dtype before the linear layers.
        t_freq = self.timestep_embedding(t, self.frequency_embedding_size).to(
            dtype=next(self.parameters()).dtype
        )
        t_emb = self.mlp(t_freq)
        return t_emb
class MPConv(torch.nn.Module):
    """Linear or 2-D convolution layer whose weight is normalized on every call.

    :param int in_channels: number of input channels.
    :param int out_channels: number of output channels.
    :param kernel: spatial kernel shape; an empty list yields a 2-D weight
        (matrix product), two entries yield a 4-D weight (``conv2d``).
    """

    def __init__(self, in_channels, out_channels, kernel):
        super().__init__()
        self.out_channels = out_channels
        weight_shape = (out_channels, in_channels, *kernel)
        self.weight = torch.nn.Parameter(torch.randn(*weight_shape))

    def forward(self, x, gain=1):
        """Apply the normalized weight to ``x``, scaled by ``gain``."""
        weight = self.weight.to(torch.float32)
        if self.training:
            # Forced weight normalization: overwrite the stored parameter.
            with torch.no_grad():
                self.weight.copy_(normalize(weight))
        fan_in = weight[0].numel()
        # Traditional weight normalization, then magnitude-preserving scaling.
        weight = normalize(weight) * (gain / np.sqrt(fan_in))
        weight = weight.to(x.dtype)
        if weight.ndim == 2:
            return x @ weight.t()
        assert weight.ndim == 4
        return F.conv2d(x, weight, padding=(weight.shape[-1] // 2,))
# -------------------------------------------------------------------------------------- | |
def mp_silu(x):
    """Return ``silu(x)`` divided by the constant 0.596."""
    activated = torch.nn.functional.silu(x)
    return activated / 0.596
class MPFourier(torch.nn.Module):
    """Random Fourier features ``sqrt(2) * cos(x * freq + phase)``.

    :param int num_channels: number of features produced per input value.
    :param bandwidth: multiplier applied to the random frequencies.
    :param device: device on which the random buffers are created.
    """

    def __init__(self, num_channels, bandwidth=1, device="cpu"):
        super().__init__()
        two_pi = 2 * np.pi
        # Frequencies are drawn before phases (fixed random-number order).
        freqs = two_pi * torch.rand(num_channels, device=device) * bandwidth
        phases = two_pi * torch.rand(num_channels, device=device)
        self.register_buffer("freqs", freqs)
        self.register_buffer("phases", phases)

    def forward(self, x):
        """Map a 1-D tensor ``x`` to features of shape ``(len(x), num_channels)``."""
        # Computation runs in float32; the result is cast back to x's dtype.
        angles = torch.outer(x.to(torch.float32), self.freqs.to(torch.float32))
        angles = angles + self.phases.to(torch.float32)
        features = angles.cos() * np.sqrt(2)
        return features.to(x.dtype)
class NoiseEmbedding(torch.nn.Module):
    """Embed a per-sample normalized noise level into ``emb_channels`` features.

    :param int num_channels: number of Fourier features.
    :param int emb_channels: number of output embedding channels.
    :param device: device passed to the Fourier feature layer.
    :param bool biasfree: use ReLU on the embedding when ``True``,
        ``mp_silu`` otherwise.
    """

    def __init__(self, num_channels=1, emb_channels=512, device="cpu", biasfree=True):
        super().__init__()
        self.emb_fourier = MPFourier(num_channels, device=device)
        self.emb_noise = MPConv(num_channels, emb_channels, kernel=[])
        self.biasfree = biasfree

    def forward(self, y, physics, factor):
        """Return the noise embedding with two trailing singleton dimensions.

        :param y: measurements; for a ``TensorList`` only the first entry is
            used for the normalization.
        :param physics: object whose ``noise_model.sigma`` is read if available.
        :param factor: additional divisor applied to the noise level.
        """
        sigma = 0.0
        # NOTE(review): torch modules are callable, so a module-based noise
        # model keeps sigma at 0.0 here — confirm this is intended.
        if hasattr(physics, "noise_model") and not callable(physics.noise_model):
            sigma = getattr(physics.noise_model, "sigma", 0.0)

        reference = y[0] if isinstance(y, TensorList) else y
        # Per-sample mean absolute value of the measurements.
        mean_abs = reference.abs().reshape(reference.size(0), -1).mean(1)
        sigma = sigma / (mean_abs + 1e-8) / factor

        emb = self.emb_noise(self.emb_fourier(sigma))
        emb = F.relu(emb) if self.biasfree else mp_silu(emb)
        return emb.unsqueeze(-1).unsqueeze(-1)
# -------------------------------------------------------------------------------------- | |
class AffineConv2d(nn.Conv2d):
    """2-D convolution whose kernels can be re-parameterized to sum to one.

    In ``"affine"`` mode the bias is disabled and the kernels returned by
    :meth:`affine` are used, so that f(a*x + 1) = a*f(x) + 1. Any other mode
    falls back to the regular ``nn.Conv2d`` forward pass.

    :param int in_channels: number of input channels.
    :param int out_channels: number of output channels.
    :param kernel_size: size of the convolution kernel.
    :param str mode: ``"affine"`` enables the re-parameterized kernels.
    :param bool bias: use a bias term (ignored in ``"affine"`` mode).
    :param stride: convolution stride.
    :param padding: padding applied on each side.
    :param dilation: kernel dilation.
    :param int groups: number of convolution groups.
    :param str padding_mode: padding mode in ``nn.Conv2d`` naming.
    :param bool blind: when ``False`` the last input channel keeps its raw
        weights and only the other channels are re-parameterized.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        mode="affine",
        bias=False,
        stride=1,
        padding=0,
        dilation=1,
        groups=1,
        padding_mode="circular",
        blind=True,
    ):
        # f(a*x + 1) = a*f(x) + 1 only holds without a bias term.
        use_bias = False if mode == "affine" else bias
        super().__init__(
            in_channels,
            out_channels,
            kernel_size,
            bias=use_bias,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            padding_mode=padding_mode,
        )
        self.blind = blind
        self.mode = mode

    def affine(self, w):
        """returns new kernels that encode affine combinations"""
        # Roll each flattened output kernel by one entry, subtract the
        # original, then add 1/numel so every kernel sums to one.
        rolled = w.view(self.out_channels, -1).roll(1, 1).view(w.size())
        return rolled - w + 1 / w[0, ...].numel()

    def forward(self, x):
        """Convolve ``x`` with the affine kernels (or plainly outside affine mode)."""
        if self.mode != "affine":
            return super().forward(x)

        if self.blind:
            kernel = self.affine(self.weight)
        else:
            kernel = torch.cat(
                (self.affine(self.weight[:, :-1, :, :]), self.weight[:, -1:, :, :]),
                dim=1,
            )

        # Translate the Conv module padding (per dimension) into the F.pad
        # layout (last dimension first, both sides listed).
        pad_amounts = []
        for amount in reversed(self.padding):
            pad_amounts.extend((amount, amount))
        pad_amounts = tuple(pad_amounts)
        # F.pad names zero padding "constant".
        pad_mode = "constant" if self.padding_mode == "zeros" else self.padding_mode

        padded = F.pad(x, pad_amounts, mode=pad_mode)
        return F.conv2d(
            padded,
            kernel,
            stride=self.stride,
            dilation=self.dilation,
            groups=self.groups,
        )
# -------------------------------------------------------------------------------------- | |
def kaiser_window(beta, length):
    """Return the Kaiser window of length `length` and shape parameter `beta`.

    :param beta: shape parameter; negative values raise ``ValueError``.
    :param int length: number of samples; values below 1 raise ``ValueError``.
    :return: 1-D tensor with ``length`` entries.
    """
    if beta < 0:
        raise ValueError("beta must be greater than 0")
    if length < 1:
        raise ValueError("length must be greater than 0")
    if length == 1:
        return torch.tensor([1.0])

    center = (length - 1) / 2
    # Sample positions rescaled to [-1, 1].
    offsets = (torch.arange(length) - center) / center
    shape = torch.tensor(beta)
    numerator = torch.i0(shape * torch.sqrt(1 - offsets**2))
    return numerator / torch.i0(shape)
def sinc_filter(factor=2, length=11, windowed=True):
    r"""
    Separable 2-D sinc filter, optionally multiplied by a Kaiser window.

    :param float factor: Downsampling factor.
    :param int length: Length of the filter along each axis.
    :param bool windowed: multiply the 1-D sinc taps by a Kaiser window
        before forming the 2-D kernel.
    :return: tensor of shape ``(1, 1, length, length)`` normalized to sum to one.
    """
    deltaf = 1 / factor
    positions = torch.arange(length) - (length - 1) / 2
    taps = torch.sinc(positions / factor)

    if windowed:
        # Pick the Kaiser shape parameter from A with the piecewise rule below.
        A = 2.285 * (length - 1) * 3.14 * deltaf + 7.95
        if A <= 21:
            beta = 0
        elif A <= 50:
            beta = 0.5842 * (A - 21) ** 0.4 + 0.07886 * (A - 21)
        else:
            beta = 0.1102 * (A - 8.7)
        taps = taps * kaiser_window(beta, length)

    # Outer product of the 1-D taps gives the separable 2-D kernel.
    row = taps.unsqueeze(0)
    kernel = row * row.T
    kernel = kernel.unsqueeze(0).unsqueeze(0)
    return kernel / kernel.sum()
class EquivMaxPool(nn.Module):
    r"""
    Downsampling by 2 that keeps the strongest of the four polyphase components.

    ``downscale`` applies a 3x3 convolution, an optional anti-aliasing filter,
    splits the result into its four stride-2 sub-grids and keeps, per batch
    element, the sub-grid with the largest L2 norm. The selected index is
    stored on the module and re-used by ``upscale``.

    :param str antialias: ``"gaussian"`` or ``"sinc"`` selects the anti-aliasing
        kernel; any other value disables the filtering.
    :param factor: forwarded to the kernel construction (``gaussian_blur`` is
        called with ``factor / 3.14``, ``sinc_filter`` with ``factor``). The
        sub-grid split itself always uses a stride of 2.
    :param device: device on which the anti-aliasing kernel is first created.
    :param int in_channels: channels of the input of ``downscale``.
    :param int out_channels: channels of the output of ``downscale``.
    :param bool bias: bias flag forwarded to both convolutions.
    :param str padding_mode: padding mode forwarded to both convolutions.
    """

    def __init__(
        self,
        antialias="gaussian",
        factor=2,
        device="cuda",
        in_channels=64,
        out_channels=64,
        bias=False,
        padding_mode="circular",
    ):
        super(EquivMaxPool, self).__init__()
        self.antialias = antialias
        if antialias == "gaussian":
            self.antialias_kernel = gaussian_blur(factor / 3.14).to(device)
        elif antialias == "sinc":
            self.antialias_kernel = sinc_filter(
                factor=factor, length=11, windowed=True
            ).to(device)
        self.conv_down = AffineConv2d(
            in_channels,
            out_channels,
            kernel_size=3,
            stride=1,
            padding=1,
            bias=bias,
            padding_mode=padding_mode,
            groups=1,
        )
        self.conv_up = AffineConv2d(
            out_channels,
            in_channels,
            kernel_size=3,
            stride=1,
            padding=1,
            bias=bias,
            padding_mode=padding_mode,
            groups=1,
        )

    def _kernel_like(self, x):
        """Return the anti-aliasing kernel on the device and in the dtype of ``x``."""
        # The kernel is a plain attribute (not a buffer), so it does not follow
        # ``module.to(...)``; align it with the input at use time instead.
        return self.antialias_kernel.to(device=x.device, dtype=x.dtype)

    def forward(self, x):
        """Alias of :meth:`downscale`."""
        return self.downscale(x)

    def downscale(self, x):
        r"""
        Apply the equivariant pooling.

        :param torch.Tensor x: input tensor of shape ``(B, C, H, W)``.
        :return: tensor of shape ``(B, C_out, H/2, W/2)``.
        """
        B, C, H, W = x.shape
        x = self.conv_down(x)
        if self.antialias == "gaussian" or self.antialias == "sinc":
            x = conv2d(x, self._kernel_like(x), padding="circular")
        # The four stride-2 sub-grids (polyphase components) of the feature map.
        x1 = x[:, :, ::2, ::2].unsqueeze(0)
        x2 = x[:, :, ::2, 1::2].unsqueeze(0)
        x3 = x[:, :, 1::2, ::2].unsqueeze(0)
        x4 = x[:, :, 1::2, 1::2].unsqueeze(0)
        out = torch.cat([x1, x2, x3, x4], dim=0)  # (4, B, C, H/2, W/2)
        ind = torch.norm(out, dim=(2, 3, 4), p=2)  # (4, B)
        ind = torch.argmax(ind, dim=0)  # (B)
        out = out[ind, torch.arange(B, device=ind.device), ...]  # (B, C, H/2, W/2)
        # Remembered so that ``upscale`` can restore the same sub-grid.
        self.ind = ind
        return out

    def upscale(self, x):
        """
        Undo :meth:`downscale` using the sub-grid index stored by its last call.

        :param torch.Tensor x: input tensor of shape ``(B, C, H, W)``.
        :return: output of ``conv_up`` applied to the ``(B, C, 2H, 2W)`` upsampled map.
        """
        B, C, H, W = x.shape
        # Buffers follow the dtype of x so reduced-precision inputs stay consistent.
        out = torch.zeros((B, C, H * 2, W * 2), device=x.device, dtype=x.dtype)
        out[:, :, ::2, ::2] = x
        ind = self.ind
        # One 2x2 one-hot filter per batch element, set from the stored index.
        filter = torch.zeros((B, 1, 2, 2), device=x.device, dtype=x.dtype)
        filter[ind == 0, :, 0, 0] = 1
        filter[ind == 1, :, 0, 1] = 1
        filter[ind == 2, :, 1, 0] = 1
        filter[ind == 3, :, 1, 1] = 1
        out = conv2d(out, filter, padding="constant")
        if self.antialias == "gaussian" or self.antialias == "sinc":
            out = conv2d(out, self._kernel_like(out), padding="circular")
        out = self.conv_up(out)
        return out
# -------------------------------------------------------------------------------------- | |
class ConvNextBaseBlock(nn.Module):
    r"""
    ConvNeXt Block mimicking DRUNet base layer (Conv + Relu + Conv)

    Two depthwise convolutions (kernel ``ksize`` and kernel 3) are summed,
    followed by a pointwise expansion, a ReLU and a pointwise projection.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        mode (str): Mode for the AffineConv2d (if needed, else ignored).
        bias (bool): Whether to use bias in convolutions. Default: False.
        ksize (int): Kernel size for the large depthwise convolution. Default: 7.
        padding_mode (str): Padding mode for convolutions. Default: 'circular'.
        mult_fact (int): Multiplier factor for expanding the number of channels.
        residual (bool): Whether to add a 1x1-convolved copy of the input to
            the output. Default: False.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        mode="",
        bias=False,
        ksize=7,
        padding_mode="circular",
        mult_fact=1,
        residual=False,
    ):
        super().__init__()

        ### DEPTHWISE SEPARABLE CONVOLUTION: (N,C,H,W) -> (N,mult_fact*C,H,W)
        # depthwise conv with big kernel
        self.dwconv_a = AffineConv2d(
            in_channels,
            in_channels,
            kernel_size=ksize,
            padding=ksize // 2,
            groups=in_channels,
            padding_mode=padding_mode,
            bias=bias,
            mode=mode,
        )
        # depthwise conv with small kernel
        self.dwconv_a_small = AffineConv2d(
            in_channels,
            in_channels,
            kernel_size=3,
            padding=3 // 2,
            groups=in_channels,
            padding_mode=padding_mode,
            bias=bias,
            mode=mode,
        )
        # pointwise conv to change number of channels
        self.pwconv_a1 = AffineConv2d(
            in_channels,
            mult_fact * in_channels,
            kernel_size=1,
            stride=1,
            padding=0,
            mode=mode,
            bias=bias,
            padding_mode=padding_mode,
            groups=1,
        )

        ### ACTIVATION
        self.act_a = nn.ReLU()

        ### POINTWISE CONVOLUTION: (N,mult_fact*C,H,W) -> (N,O,H,W)
        # NOTE(review): `mode` is not forwarded to this layer, so it uses the
        # AffineConv2d default mode — confirm this is intended.
        self.pwconv_a2 = AffineConv2d(
            mult_fact * in_channels,
            out_channels,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=bias,
            padding_mode=padding_mode,
            groups=1,
        )

        ### Needed to match the number of channels : (N,C,H,W) -> (N,O,H,W)
        self.residual = residual
        if self.residual:
            self.residual_conv = AffineConv2d(
                in_channels,
                out_channels,
                kernel_size=1,
                stride=1,
                padding=0,
                groups=1,
                padding_mode=padding_mode,
                bias=bias,
                mode=mode,
            )

    def forward(self, x_in, stream1=None, stream2=None):
        """Forward with GPU parallelization using multiple cuda streams.

        Args:
            x_in (torch.Tensor): Input of shape (N,C,H,W).
            stream1: Optional CUDA stream for the large depthwise convolution.
            stream2: Optional CUDA stream for the small depthwise convolution.

        Returns:
            torch.Tensor: Output of shape (N,O,H,W).
        """
        if stream1 is not None and stream2 is not None:
            # Run the two independent depthwise convolutions on separate streams.
            with torch.cuda.stream(stream1):
                output_a = self.dwconv_a(x_in)
            with torch.cuda.stream(stream2):
                output_a_small = self.dwconv_a_small(x_in)
            # Ensure the streams are synchronized before adding the results
            torch.cuda.synchronize()
            x = output_a + output_a_small
        else:
            x = self.dwconv_a(x_in) + self.dwconv_a_small(x_in)  # replk 7x7 with 3x3

        # Shared tail: both branches go through the same pointwise expansion.
        x = self.pwconv_a1(x)
        x = self.act_a(x)
        x = self.pwconv_a2(x)  # (N,O,H,W)

        if self.residual:
            x = self.residual_conv(x_in) + x
        return x
class ConvNextBlock2(nn.Module):
    r"""
    Residual block made of two ConvNextBaseBlock layers with a ReLU in between.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels. The output is added
            to the input, so the shapes must be compatible.
        mode (str): Mode forwarded to both base blocks.
        bias (bool): Bias flag forwarded to both base blocks.
        ksize (int): Kernel size forwarded to both base blocks.
        padding_mode (str): Padding mode forwarded to both base blocks.
        mult_fact (int): Channel expansion factor forwarded to both base blocks.
        s1: Optional first CUDA stream handed to the base blocks.
        s2: Optional second CUDA stream handed to the base blocks.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        mode="affine",
        bias=False,
        ksize=7,
        padding_mode="circular",
        mult_fact=4,
        s1=None,
        s2=None,
    ):
        super().__init__()
        shared_kwargs = dict(
            mode=mode,
            bias=bias,
            ksize=ksize,
            padding_mode=padding_mode,
            mult_fact=mult_fact,
        )
        self.block_0 = ConvNextBaseBlock(in_channels, out_channels, **shared_kwargs)
        self.block_1 = ConvNextBaseBlock(in_channels, out_channels, **shared_kwargs)

        # Non-inplace ReLU: the original author flagged a possible FP16 issue
        # with inplace=True.
        self.relu = nn.ReLU()

        # cuda streams to parallelize execution of ConvNextBaseBlock
        self.s1 = s1
        self.s2 = s2

    def forward(self, input, emb_sigma=None):
        """Return ``block_1(relu(block_0(input))) + input``; ``emb_sigma`` is unused."""
        # Streams are only passed on when both of them are set.
        use_streams = self.s1 is not None and self.s2 is not None
        stream_args = (self.s1, self.s2) if use_streams else ()

        hidden = self.block_0(input, *stream_args)
        hidden = self.relu(hidden)
        hidden = self.block_1(hidden, *stream_args)
        return hidden + input
class CondResBlock(nn.Module):
    """Residual block whose hidden features are scaled by an embedding.

    Computes ``x + conv2(relu(conv1(relu(x)) * c))`` where
    ``c = emb_linear(emb_sigma, gain) + 1``.

    :param int in_channels: number of input channels.
    :param int out_channels: number of output channels; must equal ``in_channels``.
    :param kernel_size: kernel size of both convolutions.
    :param stride: stride of both convolutions.
    :param padding: padding of both convolutions.
    :param bool bias: bias flag of both convolutions.
    :param int emb_channels: number of channels of the embedding input.
    """

    def __init__(
        self,
        in_channels=64,
        out_channels=64,
        kernel_size=3,
        stride=1,
        padding=1,
        bias=False,
        emb_channels=512,
    ):
        super(CondResBlock, self).__init__()
        assert in_channels == out_channels, "Only support in_channels==out_channels."

        # Learnable gain applied to the embedding projection.
        self.gain = torch.nn.Parameter(torch.tensor([1.0]), requires_grad=True)
        self.emb_linear = MPConv(emb_channels, out_channels, kernel=[3, 3])

        conv_args = (kernel_size, stride, padding)
        self.conv1 = nn.Conv2d(in_channels, out_channels, *conv_args, bias=bias)
        self.conv2 = nn.Conv2d(out_channels, out_channels, *conv_args, bias=bias)

    def forward(self, x, emb_sigma):
        """Apply the conditioned residual update to ``x``."""
        features = self.conv1(F.relu(x))
        # NOTE(review): emb_linear has a 3x3 kernel, so its output is 4-D
        # before the two unsqueeze calls below — confirm the expected shape
        # of emb_sigma against the caller.
        scale = self.emb_linear(emb_sigma, gain=self.gain) + 1
        scale = scale.unsqueeze(2).unsqueeze(3).to(features.dtype)
        modulated = F.relu(features * scale)
        return x + self.conv2(modulated)
""" | |
Functional blocks below | |
""" | |
from collections import OrderedDict | |
import torch | |
import torch.nn as nn | |
""" | |
# -------------------------------------------- | |
# Advanced nn.Sequential | |
# https://github.com/xinntao/BasicSR | |
# -------------------------------------------- | |
""" | |
def sequential(*args):
    """Advanced nn.Sequential.

    Flattens any ``nn.Sequential`` argument into its children and chains
    everything into a single ``nn.Sequential``.

    Args:
        nn.Sequential, nn.Module

    Returns:
        nn.Sequential, or the argument itself when exactly one is given.

    Raises:
        NotImplementedError: if the single argument is an ``OrderedDict``.
    """
    if len(args) == 1:
        (only,) = args
        if isinstance(only, OrderedDict):
            raise NotImplementedError("sequential does not support OrderedDict input.")
        return only  # No sequential is needed.

    flattened = []
    for item in args:
        if isinstance(item, nn.Sequential):
            flattened.extend(item.children())
        elif isinstance(item, nn.Module):
            flattened.append(item)
        # Anything that is not a module is silently skipped.
    return nn.Sequential(*flattened)
""" | |
# -------------------------------------------- | |
# Useful blocks | |
# https://github.com/xinntao/BasicSR | |
# -------------------------------- | |
# conv + normalization + relu (conv)
# (PixelUnShuffle) | |
# (ConditionalBatchNorm2d) | |
# concat (ConcatBlock) | |
# sum (ShortcutBlock) | |
# resblock (ResBlock) | |
# Channel Attention (CA) Layer (CALayer) | |
# Residual Channel Attention Block (RCABlock) | |
# Residual Channel Attention Group (RCAGroup) | |
# Residual Dense Block (ResidualDenseBlock_5C) | |
# Residual in Residual Dense Block (RRDB) | |
# -------------------------------------------- | |
""" | |
# -------------------------------------------- | |
# return nn.Sequential of (Conv + BN + ReLU)
# -------------------------------------------- | |
def conv(
    in_channels=64,
    out_channels=64,
    kernel_size=3,
    stride=1,
    padding=1,
    bias=True,
    mode="CBR",
    negative_slope=0.2,
):
    """Build a chain of layers described by the characters of ``mode``.

    Each character adds one layer, in order: ``C`` conv, ``T`` transposed
    conv, ``B`` batch norm, ``I`` instance norm, ``R``/``r`` ReLU
    (in-place / not), ``L``/``l`` LeakyReLU (in-place / not), ``E`` ELU,
    ``s`` Softplus, ``2``/``3``/``4`` PixelShuffle, ``U``/``u``/``v`` nearest
    upsampling by 2/3/4, ``M`` max pooling, ``A`` average pooling.

    Args:
        in_channels (int): Input channels of the (transposed) convolution.
        out_channels (int): Output channels; also the width of the norm layers.
        kernel_size (int): Kernel size of the convolution and pooling layers.
        stride (int): Stride of the convolution and pooling layers.
        padding (int): Padding of the (transposed) convolution.
        bias (bool): Bias flag of the (transposed) convolution.
        mode (str): Layer specification, one character per layer.
        negative_slope (float): Slope used by the LeakyReLU layers.

    Returns:
        The result of ``sequential`` applied to the created layers.

    Raises:
        NotImplementedError: if ``mode`` contains an unknown character.
    """
    # One factory per mode character; a fresh layer is created on every call.
    builders = {
        "C": lambda: nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            bias=bias,
        ),
        "T": lambda: nn.ConvTranspose2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            bias=bias,
        ),
        "B": lambda: nn.BatchNorm2d(
            out_channels, momentum=0.9, eps=1e-04, affine=True
        ),
        "I": lambda: nn.InstanceNorm2d(out_channels, affine=True),
        "R": lambda: nn.ReLU(inplace=True),
        "r": lambda: nn.ReLU(inplace=False),
        "L": lambda: nn.LeakyReLU(negative_slope=negative_slope, inplace=True),
        "l": lambda: nn.LeakyReLU(negative_slope=negative_slope, inplace=False),
        "E": lambda: nn.ELU(inplace=False),
        "s": lambda: nn.Softplus(),
        "2": lambda: nn.PixelShuffle(upscale_factor=2),
        "3": lambda: nn.PixelShuffle(upscale_factor=3),
        "4": lambda: nn.PixelShuffle(upscale_factor=4),
        "U": lambda: nn.Upsample(scale_factor=2, mode="nearest"),
        "u": lambda: nn.Upsample(scale_factor=3, mode="nearest"),
        "v": lambda: nn.Upsample(scale_factor=4, mode="nearest"),
        "M": lambda: nn.MaxPool2d(kernel_size=kernel_size, stride=stride, padding=0),
        "A": lambda: nn.AvgPool2d(kernel_size=kernel_size, stride=stride, padding=0),
    }

    layers = []
    for t in mode:
        build = builders.get(t)
        if build is None:
            # Include the offending character in the message.
            raise NotImplementedError("Undefined type: {}".format(t))
        layers.append(build())
    return sequential(*layers)
""" | |
# -------------------------------------------- | |
# Upsampler | |
# Kai Zhang, https://github.com/cszn/KAIR | |
# -------------------------------------------- | |
# upsample_pixelshuffle | |
# upsample_upconv | |
# upsample_convtranspose | |
# -------------------------------------------- | |
""" | |
# -------------------------------------------- | |
# conv + subp (+ relu) | |
# -------------------------------------------- | |
def upsample_pixelshuffle(
    in_channels=64,
    out_channels=3,
    kernel_size=3,
    stride=1,
    padding=1,
    bias=True,
    mode="2R",
    negative_slope=0.2,
):
    """Convolution followed by PixelShuffle (and the layers listed after the factor).

    The first character of ``mode`` is the upscaling factor (2, 3 or 4); the
    convolution outputs ``out_channels * factor**2`` channels.

    Returns:
        The layers built by ``conv`` for the mode ``"C" + mode``.
    """
    assert len(mode) < 4 and mode[0] in [
        "2",
        "3",
        "4",
    ], "mode examples: 2, 2R, 2BR, 3, ..., 4BR."
    scale = int(mode[0])
    return conv(
        in_channels,
        out_channels * (scale**2),
        kernel_size,
        stride,
        padding,
        bias,
        mode="C" + mode,
        negative_slope=negative_slope,
    )
# -------------------------------------------- | |
# nearest_upsample + conv (+ R) | |
# -------------------------------------------- | |
def upsample_upconv(
    in_channels=64,
    out_channels=3,
    kernel_size=3,
    stride=1,
    padding=1,
    bias=True,
    mode="2R",
    negative_slope=0.2,
):
    """Nearest-neighbour upsampling followed by a convolution.

    The first character of ``mode`` is the upscaling factor (2, 3 or 4) and is
    replaced by the matching upsample + conv codes understood by ``conv``.

    Returns:
        The layers built by ``conv`` for the translated mode.
    """
    assert len(mode) < 4 and mode[0] in [
        "2",
        "3",
        "4",
    ], "mode examples: 2, 2R, 2BR, 3, ..., 4BR"
    # Upsample code ("U", "u", "v") for factors 2, 3 and 4, followed by "C".
    upconv_codes = {"2": "UC", "3": "uC", "4": "vC"}
    layout = mode.replace(mode[0], upconv_codes[mode[0]])
    return conv(
        in_channels=in_channels,
        out_channels=out_channels,
        kernel_size=kernel_size,
        stride=stride,
        padding=padding,
        bias=bias,
        mode=layout,
        negative_slope=negative_slope,
    )
# -------------------------------------------- | |
# convTranspose (+ relu) | |
# -------------------------------------------- | |
def upsample_convtranspose(
    in_channels=64,
    out_channels=3,
    kernel_size=2,
    stride=2,
    padding=0,
    bias=True,
    mode="2R",
    negative_slope=0.2,
):
    """Transposed convolution used as an upsampler.

    The first character of ``mode`` (2, 3, 4 or 8) sets both the kernel size
    and the stride; the ``kernel_size`` and ``stride`` arguments are
    overridden by it.

    Returns:
        The layers built by ``conv`` with the factor replaced by ``"T"``.
    """
    assert len(mode) < 4 and mode[0] in [
        "2",
        "3",
        "4",
        "8",
    ], "mode examples: 2, 2R, 2BR, 3, ..., 4BR."
    factor = int(mode[0])
    layout = mode.replace(mode[0], "T")
    return conv(
        in_channels=in_channels,
        out_channels=out_channels,
        kernel_size=factor,
        stride=factor,
        padding=padding,
        bias=bias,
        mode=layout,
        negative_slope=negative_slope,
    )
""" | |
# -------------------------------------------- | |
# Downsampler | |
# Kai Zhang, https://github.com/cszn/KAIR | |
# -------------------------------------------- | |
# downsample_strideconv | |
# downsample_maxpool | |
# downsample_avgpool | |
# -------------------------------------------- | |
""" | |
# -------------------------------------------- | |
# strideconv (+ relu) | |
# -------------------------------------------- | |
def downsample_strideconv(
    in_channels=64,
    out_channels=64,
    kernel_size=2,
    stride=2,
    padding=0,
    bias=True,
    mode="2R",
    negative_slope=0.2,
):
    """Strided convolution used as a downsampler.

    The first character of ``mode`` (2, 3, 4 or 8) sets both the kernel size
    and the stride; the ``kernel_size`` and ``stride`` arguments are
    overridden by it.

    Returns:
        The layers built by ``conv`` with the factor replaced by ``"C"``.
    """
    assert len(mode) < 4 and mode[0] in [
        "2",
        "3",
        "4",
        "8",
    ], "mode examples: 2, 2R, 2BR, 3, ..., 4BR."
    factor = int(mode[0])
    layout = mode.replace(mode[0], "C")
    return conv(
        in_channels=in_channels,
        out_channels=out_channels,
        kernel_size=factor,
        stride=factor,
        padding=padding,
        bias=bias,
        mode=layout,
        negative_slope=negative_slope,
    )
# -------------------------------------------- | |
# maxpooling + conv (+ relu) | |
# -------------------------------------------- | |
def downsample_maxpool(
    in_channels=64,
    out_channels=64,
    kernel_size=3,
    stride=1,
    padding=0,
    bias=True,
    mode="2R",
    negative_slope=0.2,
):
    """Max pooling followed by a convolution (and the layers listed after the factor).

    The first character of ``mode`` (2 or 3) sets the pooling kernel size and
    stride; the remaining arguments configure the convolution.

    Returns:
        ``sequential`` of the pooling layer and the convolution tail.
    """
    assert len(mode) < 4 and mode[0] in [
        "2",
        "3",
    ], "mode examples: 2, 2R, 2BR, 3, ..., 3BR."
    pool_size = int(mode[0])
    # "M" (max pool) + "C" (conv) replace the factor character.
    layout = mode.replace(mode[0], "MC")
    pool = conv(
        kernel_size=pool_size,
        stride=pool_size,
        mode=layout[0],
        negative_slope=negative_slope,
    )
    pool_tail = conv(
        in_channels=in_channels,
        out_channels=out_channels,
        kernel_size=kernel_size,
        stride=stride,
        padding=padding,
        bias=bias,
        mode=layout[1:],
        negative_slope=negative_slope,
    )
    return sequential(pool, pool_tail)
# -------------------------------------------- | |
# averagepooling + conv (+ relu) | |
# -------------------------------------------- | |
def downsample_avgpool(
    in_channels=64,
    out_channels=64,
    kernel_size=3,
    stride=1,
    padding=1,
    bias=True,
    mode="2R",
    negative_slope=0.2,
):
    """Average pooling followed by a convolution (and the layers listed after the factor).

    The first character of ``mode`` (2 or 3) sets the pooling kernel size and
    stride; the remaining arguments configure the convolution.

    Returns:
        ``sequential`` of the pooling layer and the convolution tail.
    """
    assert len(mode) < 4 and mode[0] in [
        "2",
        "3",
    ], "mode examples: 2, 2R, 2BR, 3, ..., 3BR."
    pool_size = int(mode[0])
    # "A" (average pool) + "C" (conv) replace the factor character.
    layout = mode.replace(mode[0], "AC")
    pool = conv(
        kernel_size=pool_size,
        stride=pool_size,
        mode=layout[0],
        negative_slope=negative_slope,
    )
    pool_tail = conv(
        in_channels=in_channels,
        out_channels=out_channels,
        kernel_size=kernel_size,
        stride=stride,
        padding=padding,
        bias=bias,
        mode=layout[1:],
        negative_slope=negative_slope,
    )
    return sequential(pool, pool_tail)