import math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from deepinv.models.unet import BFBatchNorm2d
from deepinv.physics.blur import gaussian_blur
from deepinv.physics.functional import conv2d
from deepinv.utils import TensorList
from timm.models.layers import trunc_normal_, DropPath


def normalize(x, dim=None, eps=1e-4):
    r"""
    Divide ``x`` by a rescaled L2 norm.

    The norm is computed in float32 over ``dim`` and rescaled by
    ``sqrt(norm.numel() / x.numel())`` before ``eps`` is added.

    :param torch.Tensor x: tensor to normalize.
    :param dim: dimension(s) reduced by the norm. Defaults to every dimension
        except the first one.
    :param float eps: constant added to the rescaled norm (avoids a division by zero).
    :return: tensor with the shape and dtype of ``x``.
    """
    if dim is None:
        dim = list(range(1, x.ndim))
    norm = torch.linalg.vector_norm(x, dim=dim, keepdim=True, dtype=torch.float32)
    # torch.add(a, b, alpha=s) computes a + s * b
    norm = torch.add(eps, norm, alpha=np.sqrt(norm.numel() / x.numel()))
    return x / norm.to(x.dtype)


class TimestepEmbedding(nn.Module):
    r"""
    Sinusoidal embedding of a scalar per sample, followed by a two-layer MLP.

    :param int hidden_size: output size of the MLP.
    :param int frequency_embedding_size: size of the sinusoidal embedding fed to the MLP.
    """

    def __init__(self, hidden_size, frequency_embedding_size=256):
        super().__init__()
        self.mlp = nn.Sequential(
            nn.Linear(frequency_embedding_size, hidden_size),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size),
        )
        self.frequency_embedding_size = frequency_embedding_size

    @staticmethod
    def timestep_embedding(t, dim, max_period=10000):
        r"""
        Build the ``[cos, sin]`` sinusoidal embedding of ``t``.

        :param torch.Tensor t: 1D tensor (indexed as ``t[:, None]``).
        :param int dim: size of the embedding.
        :param float max_period: controls the lowest frequency of the embedding.
        :return: tensor of shape ``(len(t), dim)``.
        """
        half = dim // 2
        freqs = torch.exp(
            -math.log(max_period) * torch.arange(start=0, end=half) / half
        ).to(t.device)
        args = t[:, None] * freqs[None]
        embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
        if dim % 2:
            # odd sizes get one extra zero column so that the output has exactly `dim` columns
            embedding = torch.cat(
                [embedding, torch.zeros_like(embedding[:, :1])], dim=-1
            )
        return embedding

    def forward(self, t):
        r"""
        Embed ``t`` and pass it through the MLP.

        :param torch.Tensor t: 1D tensor of values to embed.
        :return: tensor of shape ``(len(t), hidden_size)``.
        """
        t_freq = self.timestep_embedding(t, self.frequency_embedding_size).to(
            dtype=next(self.parameters()).dtype
        )
        t_emb = self.mlp(t_freq)
        return t_emb


class MPConv(torch.nn.Module):
    r"""
    Linear / convolutional layer whose weights are normalized at every call.

    :param int in_channels: number of input channels.
    :param int out_channels: number of output channels.
    :param list kernel: spatial kernel shape. An empty list yields a 2D weight
        (applied as a matrix product), two entries yield a 4D weight (applied as a 2D convolution).
    """

    def __init__(self, in_channels, out_channels, kernel):
        super().__init__()
        self.out_channels = out_channels
        self.weight = torch.nn.Parameter(
            torch.randn(out_channels, in_channels, *kernel)
        )

    def forward(self, x, gain=1):
        r"""
        Apply the layer with normalized weights.

        In training mode the stored parameter is also overwritten in place
        (under ``no_grad``) with its normalized version.

        :param torch.Tensor x: input tensor.
        :param gain: multiplicative gain applied to the normalized weights.
        :return: ``x @ w.t()`` for 2D weights, ``conv2d(x, w)`` with "same"-style padding for 4D weights.
        """
        w = self.weight.to(torch.float32)
        if self.training:
            with torch.no_grad():
                self.weight.copy_(normalize(w))  # forced weight normalization
        w = normalize(w)  # traditional weight normalization
        w = w * (gain / np.sqrt(w[0].numel()))  # magnitude-preserving scaling
        w = w.to(x.dtype)
        if w.ndim == 2:
            return x @ w.t()
        assert w.ndim == 4
        return F.conv2d(x, w, padding=(w.shape[-1] // 2,))


# --------------------------------------------------------------------------------------


def mp_silu(x):
    r"""
    SiLU activation divided by the constant 0.596.

    :param torch.Tensor x: input tensor.
    :return: ``silu(x) / 0.596``.
    """
    # NOTE(review): 0.596 is presumably the magnitude-preserving normalization constant — confirm
    return torch.nn.functional.silu(x) / 0.596


class MPFourier(torch.nn.Module):
    r"""
    Random Fourier features ``sqrt(2) * cos(x * freqs + phases)`` of a 1D input.

    Frequencies and phases are drawn once (uniformly) at construction and stored as buffers.

    :param int num_channels: number of Fourier features.
    :param float bandwidth: scale applied to the random frequencies.
    :param device: device on which the buffers are created.
    """

    def __init__(self, num_channels, bandwidth=1, device="cpu"):
        super().__init__()
        self.register_buffer(
            "freqs", 2 * np.pi * torch.rand(num_channels, device=device) * bandwidth
        )
        self.register_buffer(
            "phases", 2 * np.pi * torch.rand(num_channels, device=device)
        )

    def forward(self, x):
        r"""
        :param torch.Tensor x: 1D tensor (an outer product with the frequencies is taken).
        :return: tensor of shape ``(len(x), num_channels)`` in the dtype of ``x``.
        """
        y = x.to(torch.float32)
        y = y.ger(self.freqs.to(torch.float32))
        y = y + self.phases.to(torch.float32)
        y = y.cos() * np.sqrt(2)
        return y.to(x.dtype)


class NoiseEmbedding(torch.nn.Module):
    r"""
    Embed the (normalized) noise level of a measurement.

    :param int num_channels: number of Fourier features.
    :param int emb_channels: size of the output embedding.
    :param device: device on which the Fourier buffers are created.
    :param bool biasfree: if ``True`` a ReLU is applied to the embedding, otherwise :func:`mp_silu`.
    """

    def __init__(self, num_channels=1, emb_channels=512, device="cpu", biasfree=True):
        super().__init__()
        self.emb_fourier = MPFourier(num_channels, device=device)
        self.emb_noise = MPConv(num_channels, emb_channels, kernel=[])
        self.biasfree = biasfree

    def forward(self, y, physics, factor):
        r"""
        :param y: measurement; when it is a ``TensorList`` only its first element is used.
        :param physics: object whose ``noise_model.sigma`` is read when ``noise_model``
            exists and is not callable; otherwise the noise level is taken to be 0.
        :param factor: additional divisor applied to the normalized noise level.
        :return: embedding with two trailing singleton dimensions appended.
        """
        if hasattr(physics, "noise_model") and not callable(physics.noise_model):
            sigma = getattr(physics.noise_model, "sigma", 0.0)
        else:
            sigma = 0.0

        # normalize sigma by the per-sample mean absolute value of the measurement
        ref = y[0] if isinstance(y, TensorList) else y
        sigma = sigma / (ref.abs().reshape(ref.size(0), -1).mean(1) + 1e-8) / factor

        emb_four = self.emb_fourier(sigma)
        emb = self.emb_noise(emb_four)
        if self.biasfree:
            emb = F.relu(emb)
        else:
            emb = mp_silu(emb)
        return emb.unsqueeze(-1).unsqueeze(-1)


# --------------------------------------------------------------------------------------


class AffineConv2d(nn.Conv2d):
    r"""
    2D convolution that, in ``"affine"`` mode, uses re-parameterized kernels summing to one.

    In any other mode it behaves as a regular :class:`torch.nn.Conv2d`.

    :param int in_channels: number of input channels.
    :param int out_channels: number of output channels.
    :param kernel_size: size of the convolution kernel.
    :param str mode: ``"affine"`` enables the affine re-parameterization (and disables the bias).
    :param bool bias: whether to use a bias (ignored in ``"affine"`` mode).
    :param stride: stride of the convolution.
    :param padding: padding of the convolution.
    :param dilation: dilation of the convolution.
    :param int groups: number of groups.
    :param str padding_mode: padding mode, as for :class:`torch.nn.Conv2d`.
    :param bool blind: if ``False`` the last input channel keeps its raw weights and only the
        remaining input channels are re-parameterized.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        mode="affine",
        bias=False,
        stride=1,
        padding=0,
        dilation=1,
        groups=1,
        padding_mode="circular",
        blind=True,
    ):
        if mode == "affine":  # f(a*x + 1) = a*f(x) + 1
            bias = False
        super().__init__(
            in_channels,
            out_channels,
            kernel_size,
            bias=bias,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            padding_mode=padding_mode,
        )
        self.blind = blind
        self.mode = mode

    def affine(self, w):
        """returns new kernels that encode affine combinations"""
        # rolling the flattened kernel only permutes its entries, so (rolled - w) sums to zero
        # and the added constant makes every output kernel sum to one
        return (
            w.view(self.out_channels, -1).roll(1, 1).view(w.size())
            - w
            + 1 / w[0, ...].numel()
        )

    def forward(self, x):
        r"""
        :param torch.Tensor x: input tensor.
        :return: result of the convolution (with explicit padding in ``"affine"`` mode).
        """
        if self.mode != "affine":
            return super().forward(x)

        if self.blind:
            kernel = self.affine(self.weight)
        else:
            kernel = torch.cat(
                (self.affine(self.weight[:, :-1, :, :]), self.weight[:, -1:, :, :]),
                dim=1,
            )
        # translate the padding argument of the Conv module into the one expected by F.pad
        padding = tuple(elt for elt in reversed(self.padding) for _ in range(2))
        # translate the padding_mode argument of the Conv module into the one expected by F.pad
        padding_mode = self.padding_mode if self.padding_mode != "zeros" else "constant"
        return F.conv2d(
            F.pad(x, padding, mode=padding_mode),
            kernel,
            stride=self.stride,
            dilation=self.dilation,
            groups=self.groups,
        )


# --------------------------------------------------------------------------------------


def kaiser_window(beta, length):
    r"""
    Return the Kaiser window of length ``length`` and shape parameter ``beta``.

    :param float beta: shape parameter, must be non-negative.
    :param int length: number of samples, must be at least 1.
    :return: 1D tensor of size ``length``.
    :raises ValueError: if ``beta < 0`` or ``length < 1``.
    """
    if beta < 0:
        raise ValueError("beta must be non-negative")
    if length < 1:
        raise ValueError("length must be greater than 0")
    if length == 1:
        return torch.tensor([1.0])
    half = (length - 1) / 2
    n = torch.arange(length)
    beta = torch.tensor(beta)
    return torch.i0(beta * torch.sqrt(1 - ((n - half) / half) ** 2)) / torch.i0(beta)


def sinc_filter(factor=2, length=11, windowed=True):
    r"""
    Anti-aliasing sinc filter multiplied by a Kaiser window.

    :param float factor: Downsampling factor.
    :param int length: Length of the filter.
    :param bool windowed: whether to multiply the sinc by a Kaiser window.
    :return: separable 2D filter of shape ``(1, 1, length, length)`` normalized to sum to one.
    """
    deltaf = 1 / factor

    n = torch.arange(length) - (length - 1) / 2
    kernel = torch.sinc(n / factor)

    if windowed:
        # NOTE(review): 3.14 (not math.pi) is kept as-is so that the filter values are unchanged
        A = 2.285 * (length - 1) * 3.14 * deltaf + 7.95
        if A <= 21:
            beta = 0
        elif A <= 50:
            beta = 0.5842 * (A - 21) ** 0.4 + 0.07886 * (A - 21)
        else:
            beta = 0.1102 * (A - 8.7)
        kernel = kernel * kaiser_window(beta, length)

    # outer product of the 1D filter with itself
    kernel = kernel.unsqueeze(0)
    kernel = kernel * kernel.T
    kernel = kernel.unsqueeze(0).unsqueeze(0)
    kernel = kernel / kernel.sum()
    return kernel


class EquivMaxPool(nn.Module):
    r"""
    Factor-2 pooling that keeps, for each sample, the polyphase component of largest norm.

    :meth:`downscale` stores the selected component index in ``self.ind``; :meth:`upscale`
    reads it back, so it must be called after :meth:`downscale` on a batch of the same size.

    :param str antialias: ``"gaussian"`` or ``"sinc"`` to apply the corresponding anti-aliasing
        filter; any other value disables anti-aliasing.
    :param int factor: factor used to parameterize the anti-aliasing filter.
    :param device: device on which the anti-aliasing filter is stored.
    :param int in_channels: number of channels on the fine-scale side.
    :param int out_channels: number of channels on the coarse-scale side.
    :param bool bias: bias flag passed to the convolutions.
    :param str padding_mode: padding mode of the convolutions.
    """

    def __init__(
        self,
        antialias="gaussian",
        factor=2,
        device="cuda",
        in_channels=64,
        out_channels=64,
        bias=False,
        padding_mode="circular",
    ):
        super(EquivMaxPool, self).__init__()
        self.antialias = antialias
        # NOTE(review): the kernel is a plain attribute (not a buffer), so it does not follow
        # `module.to(...)`; it lives on the `device` given here
        if antialias == "gaussian":
            self.antialias_kernel = gaussian_blur(factor / 3.14).to(device)
        elif antialias == "sinc":
            self.antialias_kernel = sinc_filter(
                factor=factor, length=11, windowed=True
            ).to(device)

        self.conv_down = AffineConv2d(
            in_channels,
            out_channels,
            kernel_size=3,
            stride=1,
            padding=1,
            bias=bias,
            padding_mode=padding_mode,
            groups=1,
        )

        self.conv_up = AffineConv2d(
            out_channels,
            in_channels,
            kernel_size=3,
            stride=1,
            padding=1,
            bias=bias,
            padding_mode=padding_mode,
            groups=1,
        )

    def forward(self, x):
        r"""Alias of :meth:`downscale`."""
        return self.downscale(x)

    def downscale(self, x):
        r"""
        Apply the equivariant pooling.

        :param torch.Tensor x: input tensor with four dimensions ``(B, C, H, W)``.
        :return: the selected polyphase component after ``conv_down`` (and anti-aliasing).
        """
        B, _, _, _ = x.shape

        x = self.conv_down(x)

        if self.antialias == "gaussian" or self.antialias == "sinc":
            x = conv2d(x, self.antialias_kernel, padding="circular")

        # the four polyphase components of a factor-2 subsampling
        x1 = x[:, :, ::2, ::2].unsqueeze(0)
        x2 = x[:, :, ::2, 1::2].unsqueeze(0)
        x3 = x[:, :, 1::2, ::2].unsqueeze(0)
        x4 = x[:, :, 1::2, 1::2].unsqueeze(0)
        out = torch.cat([x1, x2, x3, x4], dim=0)  # (4, B, C, H/2, W/2)
        ind = torch.norm(out, dim=(2, 3, 4), p=2)  # (4, B)
        ind = torch.argmax(ind, dim=0)  # (B)
        out = out[ind, torch.arange(B, device=out.device), ...]  # (B, C, H/2, W/2)
        self.ind = ind  # remembered for `upscale`

        return out

    def upscale(self, x):
        r"""
        Undo :meth:`downscale` using the component indices stored in ``self.ind``.

        :param torch.Tensor x: input tensor with four dimensions ``(B, C, H, W)``.
        :return: tensor with doubled spatial size, after anti-aliasing and ``conv_up``.
        """
        B, C, H, W = x.shape

        # zero-filled upsampling: samples are placed on the even grid
        out = torch.zeros((B, C, H * 2, W * 2), device=x.device)
        out[:, :, ::2, ::2] = x
        ind = self.ind

        # one-hot 2x2 filter per sample, located at the offset selected in `downscale`
        selector = torch.zeros((B, 1, 2, 2), device=x.device)
        selector[ind == 0, :, 0, 0] = 1
        selector[ind == 1, :, 0, 1] = 1
        selector[ind == 2, :, 1, 0] = 1
        selector[ind == 3, :, 1, 1] = 1
        # presumably shifts each sample back to its selected offset — verify against deepinv's conv2d
        out = conv2d(out, selector, padding="constant")

        if self.antialias == "gaussian" or self.antialias == "sinc":
            out = conv2d(out, self.antialias_kernel, padding="circular")

        out = self.conv_up(out)
        return out


# --------------------------------------------------------------------------------------


class ConvNextBaseBlock(nn.Module):
    r"""
    ConvNeXt Block mimicking DRUNet base layer (Conv + Relu + Conv)

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        mode (str): Mode for the AffineConv2d (if needed, else ignored).
        bias (bool): Whether to use bias in convolutions. Default: False.
        ksize (int): Kernel size for the convolutions. Default: 7.
        padding_mode (str): Padding mode for convolutions. Default: 'circular'.
        mult_fact (int): Multiplier factor for expanding the number of channels.
        residual (bool): Whether to use a residual connection. Default: False.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        mode="",
        bias=False,
        ksize=7,
        padding_mode="circular",
        mult_fact=1,
        residual=False,
    ):
        super().__init__()

        ### DEPTHWISE SEPARABLE CONVOLUTION: (N,C,H,W) -> (N,mult_fact*C,H,W)
        # depthwise conv with big kernel
        self.dwconv_a = AffineConv2d(
            in_channels,
            in_channels,
            kernel_size=ksize,
            padding=ksize // 2,
            groups=in_channels,
            padding_mode=padding_mode,
            bias=bias,
            mode=mode,
        )
        # depthwise conv with small kernel
        self.dwconv_a_small = AffineConv2d(
            in_channels,
            in_channels,
            kernel_size=3,
            padding=3 // 2,
            groups=in_channels,
            padding_mode=padding_mode,
            bias=bias,
            mode=mode,
        )
        # pointwise conv to change number of channels
        self.pwconv_a1 = AffineConv2d(
            in_channels,
            mult_fact * in_channels,
            kernel_size=1,
            stride=1,
            padding=0,
            mode=mode,
            bias=bias,
            padding_mode=padding_mode,
            groups=1,
        )

        ### ACTIVATION
        self.act_a = nn.ReLU()

        ### POINTWISE CONVOLUTION: (N,mult_fact*C,H,W) -> (N,O,H,W)
        # NOTE(review): `mode` is not forwarded here, so AffineConv2d's default ("affine")
        # applies regardless of this block's `mode` — confirm this is intended
        self.pwconv_a2 = AffineConv2d(
            mult_fact * in_channels,
            out_channels,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=bias,
            padding_mode=padding_mode,
            groups=1,
        )

        ### Needed to match the number of channels: (N,C,H,W) -> (N,O,H,W)
        self.residual = residual
        if self.residual:
            self.residual_conv = AffineConv2d(
                in_channels,
                out_channels,
                kernel_size=1,
                stride=1,
                padding=0,
                groups=1,
                padding_mode=padding_mode,
                bias=bias,
                mode=mode,
            )

    def forward(self, x_in, stream1=None, stream2=None):
        """Forward with GPU parallelization using multiple cuda streams.

        Args:
            x_in (torch.Tensor): Input tensor.
            stream1: Optional CUDA stream running the large depthwise convolution.
            stream2: Optional CUDA stream running the small depthwise convolution.

        Returns:
            torch.Tensor: Output of the block. Both streams must be given for the
            parallel path to be used; otherwise the convolutions run sequentially.
        """
        if stream1 is not None and stream2 is not None:
            # Use the streams
            with torch.cuda.stream(stream1):
                output_a = self.dwconv_a(x_in)  # Run the first convolution in stream1

            with torch.cuda.stream(stream2):
                output_a_small = self.dwconv_a_small(
                    x_in
                )  # Run the second convolution in stream2

            # Ensure the streams are synchronized before adding the results
            torch.cuda.synchronize()
            x = output_a + output_a_small
        else:
            x = self.dwconv_a(x_in) + self.dwconv_a_small(x_in)  # replk 7x7 with 3x3

        # both paths share the same pointwise convolution (the module defines no `pwconv_a`)
        x = self.pwconv_a1(x)
        x = self.act_a(x)
        x = self.pwconv_a2(x)  # (N,O,H,W)
        if self.residual:
            x = self.residual_conv(x_in) + x
        return x


class ConvNextBlock2(nn.Module):
    r"""
    Residual block made of two ConvNextBaseBlock separated by a ReLU.

    The input is added to the output, so the block requires shapes to match
    (in practice ``in_channels == out_channels``).

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        mode (str): Mode for the AffineConv2d layers. Default: 'affine'.
        bias (bool): Whether to use bias in convolutions. Default: False.
        ksize (int): Kernel size of the large depthwise convolutions. Default: 7.
        padding_mode (str): Padding mode for convolutions. Default: 'circular'.
        mult_fact (int): Channel expansion factor inside each base block. Default: 4.
        s1: Optional CUDA stream forwarded to the base blocks.
        s2: Optional CUDA stream forwarded to the base blocks.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        mode="affine",
        bias=False,
        ksize=7,
        padding_mode="circular",
        mult_fact=4,
        s1=None,
        s2=None,
    ):
        super().__init__()

        self.block_0 = ConvNextBaseBlock(
            in_channels,
            out_channels,
            mode=mode,
            bias=bias,
            ksize=ksize,
            padding_mode=padding_mode,
            mult_fact=mult_fact,
        )

        self.block_1 = ConvNextBaseBlock(
            in_channels,
            out_channels,
            mode=mode,
            bias=bias,
            ksize=ksize,
            padding_mode=padding_mode,
            mult_fact=mult_fact,
        )

        # not in-place: the original author flagged a possible issue with inplace=True in FP16
        self.relu = nn.ReLU()

        # cuda stream to parallelize execution of ConvNextBaseBlock
        self.s1 = s1
        self.s2 = s2

    def forward(self, input, emb_sigma=None):
        """Apply the two base blocks and add the input back.

        Args:
            input (torch.Tensor): Input tensor.
            emb_sigma: Unused by this block.

        Returns:
            torch.Tensor: ``block_1(relu(block_0(input))) + input``.
        """
        use_streams = self.s1 is not None and self.s2 is not None

        if use_streams:
            x = self.block_0(input, self.s1, self.s2)
        else:
            x = self.block_0(input)

        x = self.relu(x)

        if use_streams:
            x = self.block_1(x, self.s1, self.s2)
        else:
            x = self.block_1(x)

        return x + input


class CondResBlock(nn.Module):
    r"""
    Residual block (ReLU + Conv + modulation + ReLU + Conv) modulated by an embedding.

    :param int in_channels: number of input channels (must equal ``out_channels``).
    :param int out_channels: number of output channels.
    :param int kernel_size: kernel size of both convolutions.
    :param int stride: stride of both convolutions.
    :param int padding: padding of both convolutions.
    :param bool bias: whether the convolutions use a bias.
    :param int emb_channels: number of channels of the embedding.
    """

    def __init__(
        self,
        in_channels=64,
        out_channels=64,
        kernel_size=3,
        stride=1,
        padding=1,
        bias=False,
        emb_channels=512,
    ):
        super(CondResBlock, self).__init__()

        assert in_channels == out_channels, "Only support in_channels==out_channels."

        self.gain = torch.nn.Parameter(torch.tensor([1.0]), requires_grad=True)
        self.emb_linear = MPConv(emb_channels, out_channels, kernel=[3, 3])

        self.conv1 = nn.Conv2d(
            in_channels, out_channels, kernel_size, stride, padding, bias=bias
        )
        self.conv2 = nn.Conv2d(
            out_channels, out_channels, kernel_size, stride, padding, bias=bias
        )

    def forward(self, x, emb_sigma):
        r"""
        :param torch.Tensor x: input tensor.
        :param torch.Tensor emb_sigma: embedding fed to ``emb_linear``.
        :return: ``x + conv2(relu(conv1(relu(x)) * c))`` where ``c = emb_linear(emb_sigma) + 1``.
        """
        # ReLU is used here in place of mp_silu
        u = self.conv1(F.relu((x)))
        c = self.emb_linear(emb_sigma, gain=self.gain) + 1
        # NOTE(review): `emb_linear` has a 4D (3x3) weight, so `c` is already 4D before the two
        # unsqueeze calls below — confirm the intended shape of `emb_sigma` / `c`
        y = F.relu(u * c.unsqueeze(2).unsqueeze(3).to(u.dtype))
        y = self.conv2(y)
        return x + y


# Functional blocks below

from collections import OrderedDict
import torch
import torch.nn as nn

# --------------------------------------------
# Advanced nn.Sequential
# https://github.com/xinntao/BasicSR
# --------------------------------------------


def sequential(*args):
    """Advanced nn.Sequential.

    Nested ``nn.Sequential`` arguments are flattened into their children; arguments that
    are not ``nn.Module`` instances are ignored. A single argument is returned unchanged.

    Args:
        nn.Sequential, nn.Module

    Returns:
        nn.Sequential

    Raises:
        NotImplementedError: if the single argument is an ``OrderedDict``.
    """
    if len(args) == 1:
        if isinstance(args[0], OrderedDict):
            raise NotImplementedError("sequential does not support OrderedDict input.")
        return args[0]  # No sequential is needed.
    modules = []
    for module in args:
        if isinstance(module, nn.Sequential):
            modules.extend(module.children())
        elif isinstance(module, nn.Module):
            modules.append(module)
    return nn.Sequential(*modules)


# --------------------------------------------
# Useful blocks
# https://github.com/xinntao/BasicSR
# --------------------------------
# conv + normalization + relu (conv)
# (PixelUnShuffle)
# (ConditionalBatchNorm2d)
# concat (ConcatBlock)
# sum (ShortcutBlock)
# resblock (ResBlock)
# Channel Attention (CA) Layer (CALayer)
# Residual Channel Attention Block (RCABlock)
# Residual Channel Attention Group (RCAGroup)
# Residual Dense Block (ResidualDenseBlock_5C)
# Residual in Residual Dense Block (RRDB)
# --------------------------------------------


# --------------------------------------------
# return nn.Sequential of (Conv + BN + ReLU)
# --------------------------------------------
def conv(
    in_channels=64,
    out_channels=64,
    kernel_size=3,
    stride=1,
    padding=1,
    bias=True,
    mode="CBR",
    negative_slope=0.2,
):
    """Build a sequence of layers described by the characters of ``mode``.

    Recognized characters: ``C`` conv, ``T`` transposed conv, ``B`` batch norm,
    ``I`` instance norm, ``R``/``r`` ReLU (in-place / not), ``L``/``l`` LeakyReLU
    (in-place / not), ``E`` ELU, ``s`` Softplus, ``2``/``3``/``4`` PixelShuffle,
    ``U``/``u``/``v`` nearest upsampling by 2/3/4, ``M`` max pooling, ``A`` average pooling.

    Args:
        in_channels (int): Input channels of the (transposed) convolutions.
        out_channels (int): Output channels of the convolutions and size of the norms.
        kernel_size (int): Kernel size of the convolutions and poolings.
        stride (int): Stride of the convolutions and poolings.
        padding (int): Padding of the convolutions (poolings use no padding).
        bias (bool): Whether the convolutions use a bias.
        mode (str): Layer description, one character per layer.
        negative_slope (float): Negative slope of the LeakyReLU layers.

    Returns:
        nn.Module: The single layer if ``mode`` has one character, else an ``nn.Sequential``.

    Raises:
        NotImplementedError: if ``mode`` contains an unknown character.
    """
    L = []
    for t in mode:
        if t == "C":
            L.append(
                nn.Conv2d(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    kernel_size=kernel_size,
                    stride=stride,
                    padding=padding,
                    bias=bias,
                )
            )
        elif t == "T":
            L.append(
                nn.ConvTranspose2d(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    kernel_size=kernel_size,
                    stride=stride,
                    padding=padding,
                    bias=bias,
                )
            )
        elif t == "B":
            L.append(nn.BatchNorm2d(out_channels, momentum=0.9, eps=1e-04, affine=True))
        elif t == "I":
            L.append(nn.InstanceNorm2d(out_channels, affine=True))
        elif t == "R":
            L.append(nn.ReLU(inplace=True))
        elif t == "r":
            L.append(nn.ReLU(inplace=False))
        elif t == "L":
            L.append(nn.LeakyReLU(negative_slope=negative_slope, inplace=True))
        elif t == "l":
            L.append(nn.LeakyReLU(negative_slope=negative_slope, inplace=False))
        elif t == "E":
            L.append(nn.ELU(inplace=False))
        elif t == "s":
            L.append(nn.Softplus())
        elif t == "2":
            L.append(nn.PixelShuffle(upscale_factor=2))
        elif t == "3":
            L.append(nn.PixelShuffle(upscale_factor=3))
        elif t == "4":
            L.append(nn.PixelShuffle(upscale_factor=4))
        elif t == "U":
            L.append(nn.Upsample(scale_factor=2, mode="nearest"))
        elif t == "u":
            L.append(nn.Upsample(scale_factor=3, mode="nearest"))
        elif t == "v":
            L.append(nn.Upsample(scale_factor=4, mode="nearest"))
        elif t == "M":
            L.append(nn.MaxPool2d(kernel_size=kernel_size, stride=stride, padding=0))
        elif t == "A":
            L.append(nn.AvgPool2d(kernel_size=kernel_size, stride=stride, padding=0))
        else:
            # include the offending character (the previous message dropped it)
            raise NotImplementedError(f"Undefined type: {t}")
    return sequential(*L)


# --------------------------------------------
# Upsampler
# Kai Zhang, https://github.com/cszn/KAIR
# --------------------------------------------
# upsample_pixelshuffle
# upsample_upconv
# upsample_convtranspose
# --------------------------------------------


# --------------------------------------------
# conv + subp (+ relu)
# --------------------------------------------
def upsample_pixelshuffle(
    in_channels=64,
    out_channels=3,
    kernel_size=3,
    stride=1,
    padding=1,
    bias=True,
    mode="2R",
    negative_slope=0.2,
):
    """Upsample with a convolution followed by a PixelShuffle.

    The first character of ``mode`` is the upscaling factor (2, 3 or 4); the remaining
    characters are extra layers understood by :func:`conv`.

    Returns:
        nn.Module: Layers for ``"C" + mode``, the convolution producing
        ``out_channels * factor**2`` channels.
    """
    assert len(mode) < 4 and mode[0] in [
        "2",
        "3",
        "4",
    ], "mode examples: 2, 2R, 2BR, 3, ..., 4BR."
    up1 = conv(
        in_channels,
        out_channels * (int(mode[0]) ** 2),
        kernel_size,
        stride,
        padding,
        bias,
        mode="C" + mode,
        negative_slope=negative_slope,
    )
    return up1


# --------------------------------------------
# nearest_upsample + conv (+ R)
# --------------------------------------------
def upsample_upconv(
    in_channels=64,
    out_channels=3,
    kernel_size=3,
    stride=1,
    padding=1,
    bias=True,
    mode="2R",
    negative_slope=0.2,
):
    """Upsample with a nearest-neighbour upsampling followed by a convolution.

    The first character of ``mode`` is the upscaling factor (2, 3 or 4); the remaining
    characters are extra layers understood by :func:`conv`.

    Returns:
        nn.Module: The resulting sequence of layers.
    """
    assert len(mode) < 4 and mode[0] in [
        "2",
        "3",
        "4",
    ], "mode examples: 2, 2R, 2BR, 3, ..., 4BR"
    # map the factor onto the matching (upsample, conv) codes of `conv`
    if mode[0] == "2":
        uc = "UC"
    elif mode[0] == "3":
        uc = "uC"
    elif mode[0] == "4":
        uc = "vC"
    mode = mode.replace(mode[0], uc)
    up1 = conv(
        in_channels,
        out_channels,
        kernel_size,
        stride,
        padding,
        bias,
        mode=mode,
        negative_slope=negative_slope,
    )
    return up1


# --------------------------------------------
# convTranspose (+ relu)
# --------------------------------------------
def upsample_convtranspose(
    in_channels=64,
    out_channels=3,
    kernel_size=2,
    stride=2,
    padding=0,
    bias=True,
    mode="2R",
    negative_slope=0.2,
):
    """Upsample with a transposed convolution.

    The first character of ``mode`` (2, 3, 4 or 8) overrides both ``kernel_size`` and
    ``stride``; the remaining characters are extra layers understood by :func:`conv`.

    Returns:
        nn.Module: The resulting sequence of layers.
    """
    assert len(mode) < 4 and mode[0] in [
        "2",
        "3",
        "4",
        "8",
    ], "mode examples: 2, 2R, 2BR, 3, ..., 4BR."
    kernel_size = int(mode[0])
    stride = int(mode[0])
    mode = mode.replace(mode[0], "T")
    up1 = conv(
        in_channels,
        out_channels,
        kernel_size,
        stride,
        padding,
        bias,
        mode,
        negative_slope,
    )
    return up1


# --------------------------------------------
# Downsampler
# Kai Zhang, https://github.com/cszn/KAIR
# --------------------------------------------
# downsample_strideconv
# downsample_maxpool
# downsample_avgpool
# --------------------------------------------


# --------------------------------------------
# strideconv (+ relu)
# --------------------------------------------
def downsample_strideconv(
    in_channels=64,
    out_channels=64,
    kernel_size=2,
    stride=2,
    padding=0,
    bias=True,
    mode="2R",
    negative_slope=0.2,
):
    """Downsample with a strided convolution.

    The first character of ``mode`` (2, 3, 4 or 8) overrides both ``kernel_size`` and
    ``stride``; the remaining characters are extra layers understood by :func:`conv`.

    Returns:
        nn.Module: The resulting sequence of layers.
    """
    assert len(mode) < 4 and mode[0] in [
        "2",
        "3",
        "4",
        "8",
    ], "mode examples: 2, 2R, 2BR, 3, ..., 4BR."
    kernel_size = int(mode[0])
    stride = int(mode[0])
    mode = mode.replace(mode[0], "C")
    down1 = conv(
        in_channels,
        out_channels,
        kernel_size,
        stride,
        padding,
        bias,
        mode,
        negative_slope,
    )
    return down1


# --------------------------------------------
# maxpooling + conv (+ relu)
# --------------------------------------------
def downsample_maxpool(
    in_channels=64,
    out_channels=64,
    kernel_size=3,
    stride=1,
    padding=0,
    bias=True,
    mode="2R",
    negative_slope=0.2,
):
    """Downsample with a max pooling followed by a convolution.

    The first character of ``mode`` (2 or 3) is the pooling kernel size and stride; the
    remaining characters are extra layers understood by :func:`conv`.

    Returns:
        nn.Sequential: Pooling layer followed by the convolutional tail.
    """
    assert len(mode) < 4 and mode[0] in [
        "2",
        "3",
    ], "mode examples: 2, 2R, 2BR, 3, ..., 3BR."
    kernel_size_pool = int(mode[0])
    stride_pool = int(mode[0])
    mode = mode.replace(mode[0], "MC")
    # mode[0] is now "M" (pooling); mode[1:] starts with "C" (convolutional tail)
    pool = conv(
        kernel_size=kernel_size_pool,
        stride=stride_pool,
        mode=mode[0],
        negative_slope=negative_slope,
    )
    pool_tail = conv(
        in_channels,
        out_channels,
        kernel_size,
        stride,
        padding,
        bias,
        mode=mode[1:],
        negative_slope=negative_slope,
    )
    return sequential(pool, pool_tail)


# --------------------------------------------
# averagepooling + conv (+ relu)
# --------------------------------------------
def downsample_avgpool(
    in_channels=64,
    out_channels=64,
    kernel_size=3,
    stride=1,
    padding=1,
    bias=True,
    mode="2R",
    negative_slope=0.2,
):
    """Downsample with an average pooling followed by a convolution.

    The first character of ``mode`` (2 or 3) is the pooling kernel size and stride; the
    remaining characters are extra layers understood by :func:`conv`.

    Returns:
        nn.Sequential: Pooling layer followed by the convolutional tail.
    """
    assert len(mode) < 4 and mode[0] in [
        "2",
        "3",
    ], "mode examples: 2, 2R, 2BR, 3, ..., 3BR."
    kernel_size_pool = int(mode[0])
    stride_pool = int(mode[0])
    mode = mode.replace(mode[0], "AC")
    # mode[0] is now "A" (pooling); mode[1:] starts with "C" (convolutional tail)
    pool = conv(
        kernel_size=kernel_size_pool,
        stride=stride_pool,
        mode=mode[0],
        negative_slope=negative_slope,
    )
    pool_tail = conv(
        in_channels,
        out_channels,
        kernel_size,
        stride,
        padding,
        bias,
        mode=mode[1:],
        negative_slope=negative_slope,
    )
    return sequential(pool, pool_tail)