# denoising / models /blocks.py
# Yonuts's picture
# Bugfix
# 33dc149
# raw
# history blame
# 26.5 kB
import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from deepinv.models.unet import BFBatchNorm2d
from deepinv.physics.blur import gaussian_blur
from deepinv.physics.functional import conv2d
from deepinv.utils import TensorList
from timm.models.layers import trunc_normal_, DropPath
def normalize(x, dim=None, eps=1e-4):
if dim is None:
dim = list(range(1, x.ndim))
norm = torch.linalg.vector_norm(x, dim=dim, keepdim=True, dtype=torch.float32)
norm = torch.add(eps, norm, alpha=np.sqrt(norm.numel() / x.numel()))
return x / norm.to(x.dtype)
class TimestepEmbedding(nn.Module):
    """Map a batch of scalar timesteps to ``hidden_size``-dimensional vectors.

    Sinusoidal features of size ``frequency_embedding_size`` are fed through a
    two-layer MLP (Linear -> SiLU -> Linear).
    """

    def __init__(self, hidden_size, frequency_embedding_size=256):
        super().__init__()
        layers = [
            nn.Linear(frequency_embedding_size, hidden_size),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size),
        ]
        self.mlp = nn.Sequential(*layers)
        self.frequency_embedding_size = frequency_embedding_size

    @staticmethod
    def timestep_embedding(t, dim, max_period=10000):
        """Return ``[cos(t * f_k), sin(t * f_k)]`` features of width ``dim``.

        :param torch.Tensor t: 1-D tensor of timesteps (indexed as ``t[:, None]``).
        :param int dim: output width; an odd width gets one trailing zero column.
        :param max_period: controls the lowest frequency ``f_k = max_period ** (-k / (dim // 2))``.
        :return: tensor of shape ``(len(t), dim)``.
        """
        half_dim = dim // 2
        exponent = -math.log(max_period) * torch.arange(start=0, end=half_dim)
        freqs = torch.exp(exponent / half_dim).to(t.device)
        phase = t[:, None] * freqs[None]
        parts = [torch.cos(phase), torch.sin(phase)]
        if dim % 2:
            parts.append(torch.zeros_like(phase[:, :1]))
        return torch.cat(parts, dim=-1)

    def forward(self, t):
        """Embed timesteps ``t``; features are cast to the MLP's parameter dtype first."""
        param_dtype = next(self.parameters()).dtype
        features = self.timestep_embedding(t, self.frequency_embedding_size)
        return self.mlp(features.to(dtype=param_dtype))
class MPConv(torch.nn.Module):
    """Magnitude-preserving linear layer (``kernel=[]``) or 2-D convolution (``kernel=[kh, kw]``).

    The weight is re-normalized with ``normalize`` on every forward pass and scaled by
    ``gain / sqrt(fan_in)``. In training mode the stored weight is also overwritten with its
    normalized version ("forced" weight normalization).

    :param int in_channels: number of input features/channels.
    :param int out_channels: number of output features/channels.
    :param kernel: spatial kernel size as a list; empty for a linear layer.
    """

    def __init__(self, in_channels, out_channels, kernel):
        super().__init__()
        self.out_channels = out_channels
        self.weight = torch.nn.Parameter(torch.randn(out_channels, in_channels, *kernel))

    def forward(self, x, gain=1):
        """Apply the layer to ``x``; ``gain`` (number or tensor) scales the normalized weight."""
        w = self.weight.to(torch.float32)
        if self.training:
            with torch.no_grad():
                self.weight.copy_(normalize(w))  # forced weight normalization
        w = normalize(w)  # traditional weight normalization
        w = w * (gain / np.sqrt(w[0].numel()))  # magnitude-preserving scaling
        w = w.to(x.dtype)
        if w.ndim == 2:
            return x @ w.t()
        assert w.ndim == 4
        # Pad height and width independently so non-square (odd) kernels also keep the
        # spatial size; for square kernels this matches the previous single-value padding.
        return F.conv2d(x, w, padding=(w.shape[-2] // 2, w.shape[-1] // 2))
# --------------------------------------------------------------------------------------
def mp_silu(x):
    """SiLU rescaled by ``1 / 0.596`` (magnitude-preserving variant)."""
    activated = F.silu(x)
    return activated / 0.596
class MPFourier(torch.nn.Module):
    """Random Fourier features of a 1-D input.

    Each input value ``x`` is mapped to ``num_channels`` features
    ``sqrt(2) * cos(x * freqs[k] + phases[k])``. ``freqs`` are drawn uniformly from
    ``[0, 2*pi*bandwidth)`` and ``phases`` from ``[0, 2*pi)`` once, at construction, and stored
    as buffers. The ``sqrt(2)`` factor gives unit mean square over a uniformly random phase.

    :param int num_channels: number of output features.
    :param float bandwidth: scale of the random frequencies.
    :param device: device on which the buffers are created.
    """

    def __init__(self, num_channels, bandwidth=1, device="cpu"):
        super().__init__()
        self.register_buffer(
            "freqs", 2 * np.pi * torch.rand(num_channels, device=device) * bandwidth
        )
        self.register_buffer(
            "phases", 2 * np.pi * torch.rand(num_channels, device=device)
        )

    def forward(self, x):
        """Map a 1-D tensor ``x`` of length N to an ``(N, num_channels)`` tensor of ``x``'s dtype."""
        y = x.to(torch.float32)
        # torch.outer replaces the deprecated Tensor.ger; both require 1-D operands.
        y = torch.outer(y, self.freqs.to(torch.float32))
        y = y + self.phases.to(torch.float32)
        y = y.cos() * np.sqrt(2)
        return y.to(x.dtype)
class NoiseEmbedding(torch.nn.Module):
    """Embed a measurement's noise level into an ``(B, emb_channels, 1, 1)`` tensor.

    The noise level is divided by the per-sample mean absolute value of the measurement and
    by ``factor``. It is then mapped through random Fourier features (``MPFourier``), a
    magnitude-preserving linear layer (``MPConv`` with an empty kernel), and ReLU
    (``biasfree=True``) or ``mp_silu`` (``biasfree=False``).
    """

    def __init__(self, num_channels=1, emb_channels=512, device="cpu", biasfree=True):
        super().__init__()
        self.emb_fourier = MPFourier(num_channels, device=device)
        self.emb_noise = MPConv(num_channels, emb_channels, kernel=[])
        self.biasfree = biasfree

    def forward(self, y, physics, factor):
        """Return the noise embedding for measurement ``y`` produced by ``physics``.

        :param y: measurement tensor, or a TensorList whose first entry is used for the scale.
        :param physics: object optionally exposing ``noise_model.sigma``.
        :param factor: extra divisor applied to the normalized noise level.
        """
        # NOTE(review): sigma is read only when noise_model is *not* callable; a callable
        # noise model (e.g. an nn.Module) falls back to sigma = 0.0 — confirm this is intended.
        use_model_sigma = hasattr(physics, "noise_model") and not callable(physics.noise_model)
        sigma = getattr(physics.noise_model, "sigma", 0.0) if use_model_sigma else 0.0
        reference = y[0] if isinstance(y, TensorList) else y
        mean_magnitude = reference.abs().reshape(reference.size(0), -1).mean(1)
        # presumably sigma is a Python scalar here, so this yields one value per sample — verify for tensor sigma
        sigma = sigma / (mean_magnitude + 1e-8) / factor
        emb = self.emb_noise(self.emb_fourier(sigma))
        emb = F.relu(emb) if self.biasfree else mp_silu(emb)
        return emb[..., None, None]
# --------------------------------------------------------------------------------------
class AffineConv2d(nn.Conv2d):
    """Conv2d whose kernels can be constrained to affine combinations.

    With ``mode="affine"`` the bias is disabled. Each output channel's kernel is
    reparametrized (see ``affine``) so its entries sum to 1, which gives
    ``f(a*x + 1) = a*f(x) + 1``. With ``blind=False`` the last input channel is excluded
    from that constraint and uses its raw weights. Any other ``mode`` behaves exactly like
    ``nn.Conv2d``.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        mode="affine",
        bias=False,
        stride=1,
        padding=0,
        dilation=1,
        groups=1,
        padding_mode="circular",
        blind=True,
    ):
        # f(a*x + 1) = a*f(x) + 1 requires no bias in affine mode
        use_bias = False if mode == "affine" else bias
        super().__init__(
            in_channels,
            out_channels,
            kernel_size,
            bias=use_bias,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            padding_mode=padding_mode,
        )
        self.blind = blind
        self.mode = mode

    def affine(self, w):
        """returns new kernels that encode affine combinations"""
        # (rolled copy - original) sums to 0 per output channel; the constant adds 1 in total
        flat = w.view(self.out_channels, -1)
        shifted = flat.roll(1, 1).view(w.size())
        return shifted - w + 1 / w[0, ...].numel()

    def forward(self, x):
        if self.mode != "affine":
            return super().forward(x)
        if self.blind:
            kernel = self.affine(self.weight)
        else:
            constrained = self.affine(self.weight[:, :-1, :, :])
            kernel = torch.cat((constrained, self.weight[:, -1:, :, :]), dim=1)
        # Conv2d stores (padH, padW); F.pad wants (left, right, top, bottom)
        pad = []
        for amount in reversed(self.padding):
            pad += [amount, amount]
        # Conv2d calls zero padding "zeros", F.pad calls it "constant"
        pad_mode = "constant" if self.padding_mode == "zeros" else self.padding_mode
        padded = F.pad(x, tuple(pad), mode=pad_mode)
        return F.conv2d(
            padded,
            kernel,
            stride=self.stride,
            dilation=self.dilation,
            groups=self.groups,
        )
# --------------------------------------------------------------------------------------
def kaiser_window(beta, length):
    """Return the symmetric Kaiser window of length `length` and shape parameter `beta`.

    Computes ``I0(beta * sqrt(1 - ((n - h) / h) ** 2)) / I0(beta)`` with ``h = (length - 1) / 2``
    for ``n = 0 .. length - 1``.

    :param beta: non-negative shape parameter (number or 0-d tensor); ``0`` gives all ones.
    :param int length: number of samples, at least 1.
    :return: float32 tensor of shape ``(length,)``.
    :raises ValueError: if ``beta < 0`` or ``length < 1``.
    """
    if beta < 0:
        raise ValueError("beta must be non-negative")
    if length < 1:
        raise ValueError("length must be greater than 0")
    if length == 1:
        return torch.tensor([1.0])
    half = (length - 1) / 2
    n = torch.arange(length)
    # as_tensor with a float dtype accepts integer beta (sinc_filter passes beta = 0)
    # and tensor beta without torch.tensor's copy-construct warning.
    beta = torch.as_tensor(beta, dtype=torch.float32)
    return torch.i0(beta * torch.sqrt(1 - ((n - half) / half) ** 2)) / torch.i0(beta)
def sinc_filter(factor=2, length=11, windowed=True):
    r"""
    Separable 2-D anti-aliasing sinc filter, optionally multiplied by a Kaiser window.

    The 1-D taps are ``sinc(n / factor)`` for ``n`` centred on the filter. With ``windowed=True``
    they are tapered by ``kaiser_window`` with a ``beta`` chosen from an attenuation estimate.
    The 2-D kernel is the outer product of the taps, normalized to sum to 1.

    :param float factor: Downsampling factor.
    :param int length: Length of the filter.
    :param bool windowed: Whether to apply the Kaiser window.
    :return: tensor of shape ``(1, 1, length, length)``.
    """
    delta_f = 1 / factor
    taps = torch.sinc((torch.arange(length) - (length - 1) / 2) / factor)
    if windowed:
        attenuation = 2.285 * (length - 1) * 3.14 * delta_f + 7.95
        if attenuation > 50:
            beta = 0.1102 * (attenuation - 8.7)
        elif attenuation > 21:
            beta = 0.5842 * (attenuation - 21) ** 0.4 + 0.07886 * (attenuation - 21)
        else:
            beta = 0
        taps = taps * kaiser_window(beta, length)
    row = taps.unsqueeze(0)
    kernel = (row * row.T)[None, None]
    return kernel / kernel.sum()
class EquivMaxPool(nn.Module):
    r"""
    Max pooling layer that is equivariant to translations.

    ``downscale`` (also used by ``forward``) applies ``conv_down`` and optionally low-pass filters
    the result. It then splits the result into its four 2x2 polyphase components
    (``x[..., i::2, j::2]``) and keeps, per sample, the component with the largest L2 norm. The
    chosen index is stored in ``self.ind``. ``upscale`` reinserts values at that same phase, so it
    must be called after a ``downscale`` on a batch of the same size.

    :param str antialias: ``"gaussian"`` or ``"sinc"`` to low-pass filter before decimation and
        after reinsertion; any other value disables filtering.
    :param int factor: width parameter of the anti-aliasing filter (decimation is always by 2).
    :param device: device on which the anti-aliasing kernel is created.
    :param int in_channels: channels of the input of ``downscale`` / output of ``upscale``.
    :param int out_channels: channels of the output of ``downscale`` / input of ``upscale``.
    :param bool bias: forwarded to the two 3x3 AffineConv2d layers (whose default affine mode
        disables the bias).
    :param str padding_mode: padding mode of the two 3x3 AffineConv2d layers.
    """

    def __init__(
        self,
        antialias="gaussian",
        factor=2,
        device="cuda",
        in_channels=64,
        out_channels=64,
        bias=False,
        padding_mode="circular",
    ):
        super(EquivMaxPool, self).__init__()
        self.antialias = antialias
        if antialias == "gaussian":
            kernel = gaussian_blur(factor / 3.14).to(device)
        elif antialias == "sinc":
            kernel = sinc_filter(factor=factor, length=11, windowed=True).to(device)
        else:
            kernel = None
        if kernel is not None:
            # Non-persistent buffer: follows module.to()/.cuda()/.half() like the conv weights,
            # but stays out of state_dict so existing checkpoints still load.
            self.register_buffer("antialias_kernel", kernel, persistent=False)
        self.conv_down = AffineConv2d(
            in_channels,
            out_channels,
            kernel_size=3,
            stride=1,
            padding=1,
            bias=bias,
            padding_mode=padding_mode,
            groups=1,
        )
        self.conv_up = AffineConv2d(
            out_channels,
            in_channels,
            kernel_size=3,
            stride=1,
            padding=1,
            bias=bias,
            padding_mode=padding_mode,
            groups=1,
        )

    def forward(self, x):
        return self.downscale(x)

    def downscale(self, x):
        r"""
        Apply the equivariant pooling.

        :param torch.Tensor x: input tensor of shape (B, C, H, W); H and W are assumed even
            (otherwise the four phases differ in size and cannot be stacked).
        :return: tensor of shape (B, out_channels, H/2, W/2).
        """
        B, _, _, _ = x.shape
        x = self.conv_down(x)
        if self.antialias == "gaussian" or self.antialias == "sinc":
            x = conv2d(x, self.antialias_kernel, padding="circular")
        phases = torch.stack(
            [x[:, :, ::2, ::2], x[:, :, ::2, 1::2], x[:, :, 1::2, ::2], x[:, :, 1::2, 1::2]],
            dim=0,
        )  # (4, B, C, H/2, W/2)
        energy = torch.norm(phases, dim=(2, 3, 4), p=2)  # (4, B)
        ind = torch.argmax(energy, dim=0)  # (B,)
        out = phases[ind, torch.arange(B, device=x.device), ...]  # (B, C, H/2, W/2)
        self.ind = ind
        return out

    def upscale(self, x):
        r"""
        Invert ``downscale``'s decimation using the phase indices stored in ``self.ind``.

        :param torch.Tensor x: tensor of shape (B, C, H, W), same B as the last ``downscale`` input.
        :return: tensor of shape (B, in_channels, 2H, 2W).
        """
        B, C, H, W = x.shape
        # Allocate in x's dtype so non-float32 models do not hit a dtype mismatch in conv_up.
        out = torch.zeros((B, C, H * 2, W * 2), device=x.device, dtype=x.dtype)
        out[:, :, ::2, ::2] = x
        ind = self.ind
        # One-hot 2x2 kernel per sample that moves the values to the phase chosen in downscale.
        shift = torch.zeros((B, 1, 2, 2), device=x.device, dtype=x.dtype)
        shift[ind == 0, :, 0, 0] = 1
        shift[ind == 1, :, 0, 1] = 1
        shift[ind == 2, :, 1, 0] = 1
        shift[ind == 3, :, 1, 1] = 1
        out = conv2d(out, shift, padding="constant")
        if self.antialias == "gaussian" or self.antialias == "sinc":
            out = conv2d(out, self.antialias_kernel, padding="circular")
        out = self.conv_up(out)
        return out
# --------------------------------------------------------------------------------------
class ConvNextBaseBlock(nn.Module):
    r"""
    ConvNeXt Block mimicking DRUNet base layer (Conv + Relu + Conv).

    Computes ``pwconv_a2(relu(pwconv_a1(dwconv_a(x) + dwconv_a_small(x))))``, i.e. the sum of a
    ``ksize`` x ``ksize`` and a 3x3 depthwise convolution, a pointwise expansion to
    ``mult_fact * in_channels`` channels, ReLU, and a pointwise projection to ``out_channels``.
    With ``residual=True`` a 1x1 projection of the input is added.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        mode (str): AffineConv2d mode for every conv except ``pwconv_a2``
            (``"affine"`` for the affine reparametrization, otherwise a plain convolution).
        bias (bool): Whether to use bias in convolutions (AffineConv2d drops it in affine mode).
            Default: False.
        ksize (int): Kernel size of the large depthwise convolution. Default: 7.
        padding_mode (str): Padding mode for convolutions. Default: 'circular'.
        mult_fact (int): Channel expansion factor of the first pointwise conv. Default: 1.
        residual (bool): Whether to add a 1x1-projected residual connection. Default: False.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        mode="",
        bias=False,
        ksize=7,
        padding_mode="circular",
        mult_fact=1,
        residual=False,
    ):
        super().__init__()
        ### DEPTHWISE CONVOLUTIONS: (N,C,H,W) -> (N,C,H,W), summed in forward
        # depthwise conv with big kernel
        self.dwconv_a = AffineConv2d(
            in_channels,
            in_channels,
            kernel_size=ksize,
            padding=ksize // 2,
            groups=in_channels,
            padding_mode=padding_mode,
            bias=bias,
            mode=mode,
        )
        # depthwise conv with small kernel
        self.dwconv_a_small = AffineConv2d(
            in_channels,
            in_channels,
            kernel_size=3,
            padding=3 // 2,
            groups=in_channels,
            padding_mode=padding_mode,
            bias=bias,
            mode=mode,
        )
        ### POINTWISE EXPANSION: (N,C,H,W) -> (N,mult_fact*C,H,W)
        self.pwconv_a1 = AffineConv2d(
            in_channels,
            mult_fact * in_channels,
            kernel_size=1,
            stride=1,
            padding=0,
            mode=mode,
            bias=bias,
            padding_mode=padding_mode,
            groups=1,
        )
        ### ACTIVATION
        self.act_a = nn.ReLU()
        ### POINTWISE PROJECTION: (N,mult_fact*C,H,W) -> (N,O,H,W)
        # NOTE(review): `mode` is not forwarded here, so this conv always uses AffineConv2d's
        # default mode; kept unchanged for checkpoint compatibility — confirm intent.
        self.pwconv_a2 = AffineConv2d(
            mult_fact * in_channels,
            out_channels,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=bias,
            padding_mode=padding_mode,
            groups=1,
        )
        ### Needed to match the number of channels: (N,C,H,W) -> (N,O,H,W)
        self.residual = residual
        if self.residual:
            self.residual_conv = AffineConv2d(
                in_channels,
                out_channels,
                kernel_size=1,
                stride=1,
                padding=0,
                groups=1,
                padding_mode=padding_mode,
                bias=bias,
                mode=mode,
            )

    def forward(self, x_in, stream1=None, stream2=None):
        """Apply the block to ``x_in``.

        When both ``stream1`` and ``stream2`` (CUDA streams) are given, the two depthwise
        convolutions are launched on them concurrently; otherwise everything runs on the
        current stream.
        """
        if stream1 is not None and stream2 is not None:
            # Side streams must not start reading x_in before the producing stream wrote it.
            producer = torch.cuda.current_stream()
            stream1.wait_stream(producer)
            stream2.wait_stream(producer)
            with torch.cuda.stream(stream1):
                output_a = self.dwconv_a(x_in)  # Run the first convolution in stream1
            with torch.cuda.stream(stream2):
                output_a_small = self.dwconv_a_small(x_in)  # Run the second convolution in stream2
            # Ensure both side streams finished before their results are combined.
            torch.cuda.synchronize()
            x = output_a + output_a_small
        else:
            x = self.dwconv_a(x_in) + self.dwconv_a_small(x_in)  # large kernel + 3x3 kernel
        # Previously the stream path called a nonexistent `self.pwconv_a`; both paths now share
        # the single pointwise expansion.
        x = self.pwconv_a1(x)
        x = self.act_a(x)
        x = self.pwconv_a2(x)  # (N,O,H,W)
        if self.residual:
            x = self.residual_conv(x_in) + x
        return x
class ConvNextBlock2(nn.Module):
    r"""
    Residual pair of ConvNextBaseBlock layers: ``input + block_1(relu(block_0(input)))``.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Output channels of each inner block. ``block_1`` is built for
            ``in_channels`` inputs and the result is added to ``input``, so in practice this
            must equal ``in_channels``.
        mode (str): AffineConv2d mode forwarded to both inner blocks. Default: 'affine'.
        bias (bool): Whether the inner convolutions use bias. Default: False.
        ksize (int): Kernel size of the large depthwise convolutions. Default: 7.
        padding_mode (str): Padding mode for convolutions. Default: 'circular'.
        mult_fact (int): Channel expansion factor inside each inner block. Default: 4.
        s1, s2: Optional CUDA streams; when both are set they are passed to the inner blocks,
            which then run their two depthwise convolutions concurrently.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        mode="affine",
        bias=False,
        ksize=7,
        padding_mode="circular",
        mult_fact=4,
        s1=None,
        s2=None,
    ):
        super().__init__()
        block_kwargs = dict(
            mode=mode,
            bias=bias,
            ksize=ksize,
            padding_mode=padding_mode,
            mult_fact=mult_fact,
        )
        self.block_0 = ConvNextBaseBlock(in_channels, out_channels, **block_kwargs)
        self.block_1 = ConvNextBaseBlock(in_channels, out_channels, **block_kwargs)
        # Not in-place: the original author noted issues with inplace=True when working in FP16.
        self.relu = nn.ReLU()
        # cuda streams used to parallelize execution inside ConvNextBaseBlock
        self.s1 = s1
        self.s2 = s2

    def _run(self, block, x):
        """Call ``block`` on ``x``, passing the CUDA streams only when both are set."""
        if self.s1 is None or self.s2 is None:
            return block(x)
        return block(x, self.s1, self.s2)

    def forward(self, input, emb_sigma=None):
        """Apply the residual pair; ``emb_sigma`` is accepted but unused."""
        hidden = self.relu(self._run(self.block_0, input))
        return self._run(self.block_1, hidden) + input
class CondResBlock(nn.Module):
    """Residual block modulated by a noise embedding.

    Computes ``x + conv2(relu(conv1(relu(x)) * c))`` with the per-channel scale
    ``c = emb_linear(emb_sigma, gain=self.gain) + 1``. ``emb_linear`` is a 3x3 ``MPConv``
    applied to ``emb_sigma`` (a 4-D embedding, presumably ``(B, emb_channels, 1, 1)`` from
    ``NoiseEmbedding`` — confirm), so ``c`` is 4-D and broadcasts directly against ``conv1``'s
    output.

    :param int in_channels: input channels (must equal ``out_channels``).
    :param int out_channels: output channels.
    :param int kernel_size: kernel size of ``conv1``/``conv2``.
    :param int stride: stride of ``conv1``/``conv2``.
    :param int padding: padding of ``conv1``/``conv2``.
    :param bool bias: whether ``conv1``/``conv2`` use bias.
    :param int emb_channels: channels of the conditioning embedding.
    """

    def __init__(
        self,
        in_channels=64,
        out_channels=64,
        kernel_size=3,
        stride=1,
        padding=1,
        bias=False,
        emb_channels=512,
    ):
        super(CondResBlock, self).__init__()
        assert in_channels == out_channels, "Only support in_channels==out_channels."
        # learnable gain of the magnitude-preserving embedding projection
        self.gain = torch.nn.Parameter(torch.tensor([1.0]), requires_grad=True)
        self.emb_linear = MPConv(emb_channels, out_channels, kernel=[3, 3])
        self.conv1 = nn.Conv2d(
            in_channels, out_channels, kernel_size, stride, padding, bias=bias
        )
        self.conv2 = nn.Conv2d(
            out_channels, out_channels, kernel_size, stride, padding, bias=bias
        )

    def forward(self, x, emb_sigma):
        u = self.conv1(F.relu(x))
        c = self.emb_linear(emb_sigma, gain=self.gain) + 1
        # c comes out of a 2-D convolution and is already (B, C, h, w). The former
        # c.unsqueeze(2).unsqueeze(3) made it 6-D, so u * c broadcast to 6-D and conv2 failed.
        y = F.relu(u * c.to(u.dtype))
        y = self.conv2(y)
        return x + y
"""
Functional blocks below
"""
from collections import OrderedDict
import torch
import torch.nn as nn
"""
# --------------------------------------------
# Advanced nn.Sequential
# https://github.com/xinntao/BasicSR
# --------------------------------------------
"""
def sequential(*args):
    """Advanced nn.Sequential.

    A single argument is returned unchanged (an OrderedDict is rejected). Otherwise the
    nn.Module arguments are collected into a new nn.Sequential: nn.Sequential arguments
    contribute their direct children, other modules are added as-is, and non-module
    arguments are ignored.

    Args:
        nn.Sequential, nn.Module
    Returns:
        nn.Sequential
    """
    if len(args) == 1:
        (only,) = args
        if isinstance(only, OrderedDict):
            raise NotImplementedError("sequential does not support OrderedDict input.")
        return only  # No sequential is needed.
    flattened = []
    for item in args:
        if isinstance(item, nn.Sequential):
            flattened.extend(item.children())
        elif isinstance(item, nn.Module):
            flattened.append(item)
    return nn.Sequential(*flattened)
"""
# --------------------------------------------
# Useful blocks
# https://github.com/xinntao/BasicSR
# --------------------------------
# conv + normalization + relu (conv)
# (PixelUnShuffle)
# (ConditionalBatchNorm2d)
# concat (ConcatBlock)
# sum (ShortcutBlock)
# resblock (ResBlock)
# Channel Attention (CA) Layer (CALayer)
# Residual Channel Attention Block (RCABlock)
# Residual Channel Attention Group (RCAGroup)
# Residual Dense Block (ResidualDenseBlock_5C)
# Residual in Residual Dense Block (RRDB)
# --------------------------------------------
"""
# --------------------------------------------
# return nn.Sequential of (Conv + BN + ReLU)
# --------------------------------------------
def conv(
    in_channels=64,
    out_channels=64,
    kernel_size=3,
    stride=1,
    padding=1,
    bias=True,
    mode="CBR",
    negative_slope=0.2,
):
    """Build a stack of layers described by ``mode``, one character per layer.

    Codes: ``C`` Conv2d, ``T`` ConvTranspose2d, ``B`` BatchNorm2d, ``I`` InstanceNorm2d,
    ``R``/``r`` ReLU (in-place / not), ``L``/``l`` LeakyReLU (in-place / not), ``E`` ELU,
    ``s`` Softplus, ``2``/``3``/``4`` PixelShuffle, ``U``/``u``/``v`` nearest Upsample
    x2/x3/x4, ``M`` MaxPool2d, ``A`` AvgPool2d (pooling uses ``kernel_size`` and ``stride``
    with no padding).

    :return: the result of ``sequential`` on the layers (the layer itself for a
        one-character mode, otherwise an nn.Sequential).
    :raises NotImplementedError: if ``mode`` contains an unknown character.
    """
    # Factories (not instances) so every occurrence of a code gets its own module.
    layer_factories = {
        "C": lambda: nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            bias=bias,
        ),
        "T": lambda: nn.ConvTranspose2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            bias=bias,
        ),
        # momentum=0.9 kept from the original; in PyTorch it is the weight of the new batch statistic
        "B": lambda: nn.BatchNorm2d(out_channels, momentum=0.9, eps=1e-04, affine=True),
        "I": lambda: nn.InstanceNorm2d(out_channels, affine=True),
        "R": lambda: nn.ReLU(inplace=True),
        "r": lambda: nn.ReLU(inplace=False),
        "L": lambda: nn.LeakyReLU(negative_slope=negative_slope, inplace=True),
        "l": lambda: nn.LeakyReLU(negative_slope=negative_slope, inplace=False),
        "E": lambda: nn.ELU(inplace=False),
        "s": lambda: nn.Softplus(),
        "2": lambda: nn.PixelShuffle(upscale_factor=2),
        "3": lambda: nn.PixelShuffle(upscale_factor=3),
        "4": lambda: nn.PixelShuffle(upscale_factor=4),
        "U": lambda: nn.Upsample(scale_factor=2, mode="nearest"),
        "u": lambda: nn.Upsample(scale_factor=3, mode="nearest"),
        "v": lambda: nn.Upsample(scale_factor=4, mode="nearest"),
        "M": lambda: nn.MaxPool2d(kernel_size=kernel_size, stride=stride, padding=0),
        "A": lambda: nn.AvgPool2d(kernel_size=kernel_size, stride=stride, padding=0),
    }
    layers = []
    for t in mode:
        if t not in layer_factories:
            # The old message lacked a "{}" placeholder and silently dropped the bad code.
            raise NotImplementedError("Undefined type: {}".format(t))
        layers.append(layer_factories[t]())
    return sequential(*layers)
"""
# --------------------------------------------
# Upsampler
# Kai Zhang, https://github.com/cszn/KAIR
# --------------------------------------------
# upsample_pixelshuffle
# upsample_upconv
# upsample_convtranspose
# --------------------------------------------
"""
# --------------------------------------------
# conv + subp (+ relu)
# --------------------------------------------
def upsample_pixelshuffle(
    in_channels=64,
    out_channels=3,
    kernel_size=3,
    stride=1,
    padding=1,
    bias=True,
    mode="2R",
    negative_slope=0.2,
):
    """Conv to ``out_channels * s**2`` channels followed by PixelShuffle(s) and optional layers.

    ``mode`` starts with the scale ``s`` ("2", "3" or "4"), optionally followed by up to two
    more ``conv`` layer codes (e.g. "2R", "2BR").
    """
    assert len(mode) < 4 and mode[0] in ("2", "3", "4"), "mode examples: 2, 2R, 2BR, 3, ..., 4BR."
    scale = int(mode[0])
    expanded_channels = out_channels * scale ** 2
    return conv(
        in_channels,
        expanded_channels,
        kernel_size,
        stride,
        padding,
        bias,
        mode="C" + mode,
        negative_slope=negative_slope,
    )
# --------------------------------------------
# nearest_upsample + conv (+ R)
# --------------------------------------------
def upsample_upconv(
    in_channels=64,
    out_channels=3,
    kernel_size=3,
    stride=1,
    padding=1,
    bias=True,
    mode="2R",
    negative_slope=0.2,
):
    """Nearest-neighbour upsampling followed by a conv and optional extra layers.

    The leading scale digit of ``mode`` ("2", "3" or "4") is expanded to the matching
    ``conv`` upsample code plus "C" (e.g. "2R" -> "UCR").
    """
    assert len(mode) < 4 and mode[0] in ("2", "3", "4"), "mode examples: 2, 2R, 2BR, 3, ..., 4BR"
    upsample_then_conv = {"2": "UC", "3": "uC", "4": "vC"}[mode[0]]
    layer_mode = mode.replace(mode[0], upsample_then_conv)
    return conv(
        in_channels,
        out_channels,
        kernel_size,
        stride,
        padding,
        bias,
        mode=layer_mode,
        negative_slope=negative_slope,
    )
# --------------------------------------------
# convTranspose (+ relu)
# --------------------------------------------
def upsample_convtranspose(
    in_channels=64,
    out_channels=3,
    kernel_size=2,
    stride=2,
    padding=0,
    bias=True,
    mode="2R",
    negative_slope=0.2,
):
    """Transposed conv upsampling followed by optional extra layers.

    The leading digit of ``mode`` ("2", "3", "4" or "8") sets both the kernel size and the
    stride; the ``kernel_size`` and ``stride`` arguments are therefore overridden.
    """
    assert len(mode) < 4 and mode[0] in ("2", "3", "4", "8"), "mode examples: 2, 2R, 2BR, 3, ..., 4BR."
    factor = int(mode[0])
    layer_mode = mode.replace(mode[0], "T")
    return conv(
        in_channels,
        out_channels,
        factor,
        factor,
        padding,
        bias,
        layer_mode,
        negative_slope,
    )
"""
# --------------------------------------------
# Downsampler
# Kai Zhang, https://github.com/cszn/KAIR
# --------------------------------------------
# downsample_strideconv
# downsample_maxpool
# downsample_avgpool
# --------------------------------------------
"""
# --------------------------------------------
# strideconv (+ relu)
# --------------------------------------------
def downsample_strideconv(
    in_channels=64,
    out_channels=64,
    kernel_size=2,
    stride=2,
    padding=0,
    bias=True,
    mode="2R",
    negative_slope=0.2,
):
    """Strided conv downsampling followed by optional extra layers.

    The leading digit of ``mode`` ("2", "3", "4" or "8") sets both the kernel size and the
    stride; the ``kernel_size`` and ``stride`` arguments are therefore overridden.
    """
    assert len(mode) < 4 and mode[0] in ("2", "3", "4", "8"), "mode examples: 2, 2R, 2BR, 3, ..., 4BR."
    factor = int(mode[0])
    layer_mode = mode.replace(mode[0], "C")
    return conv(
        in_channels,
        out_channels,
        factor,
        factor,
        padding,
        bias,
        layer_mode,
        negative_slope,
    )
# --------------------------------------------
# maxpooling + conv (+ relu)
# --------------------------------------------
def downsample_maxpool(
    in_channels=64,
    out_channels=64,
    kernel_size=3,
    stride=1,
    padding=0,
    bias=True,
    mode="2R",
    negative_slope=0.2,
):
    """Max pooling (window and stride from the leading digit of ``mode``) then a conv tail.

    ``mode`` is a scale digit ("2" or "3") optionally followed by extra ``conv`` layer codes;
    the digit is expanded to "MC", giving a MaxPool2d followed by ``conv(..., mode="C...")``.
    """
    assert len(mode) < 4 and mode[0] in ("2", "3"), "mode examples: 2, 2R, 2BR, 3, ..., 3BR."
    pool_size = int(mode[0])
    expanded = mode.replace(mode[0], "MC")
    pool = conv(
        kernel_size=pool_size,
        stride=pool_size,
        mode=expanded[0],
        negative_slope=negative_slope,
    )
    tail = conv(
        in_channels,
        out_channels,
        kernel_size,
        stride,
        padding,
        bias,
        mode=expanded[1:],
        negative_slope=negative_slope,
    )
    return sequential(pool, tail)
# --------------------------------------------
# averagepooling + conv (+ relu)
# --------------------------------------------
def downsample_avgpool(
    in_channels=64,
    out_channels=64,
    kernel_size=3,
    stride=1,
    padding=1,
    bias=True,
    mode="2R",
    negative_slope=0.2,
):
    """Average pooling (window and stride from the leading digit of ``mode``) then a conv tail.

    ``mode`` is a scale digit ("2" or "3") optionally followed by extra ``conv`` layer codes;
    the digit is expanded to "AC", giving an AvgPool2d followed by ``conv(..., mode="C...")``.
    """
    assert len(mode) < 4 and mode[0] in ("2", "3"), "mode examples: 2, 2R, 2BR, 3, ..., 3BR."
    pool_size = int(mode[0])
    expanded = mode.replace(mode[0], "AC")
    pool = conv(
        kernel_size=pool_size,
        stride=pool_size,
        mode=expanded[0],
        negative_slope=negative_slope,
    )
    tail = conv(
        in_channels,
        out_channels,
        kernel_size,
        stride,
        padding,
        bias,
        mode=expanded[1:],
        negative_slope=negative_slope,
    )
    return sequential(pool, tail)