# Code borrowed from Kai Zhang https://github.com/cszn/DPIR/tree/master/models
import re
import math
import functools

import deepinv as dinv
from deepinv.utils import plot, TensorList
import torch
from torch.func import vmap
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
from deepinv.optim.utils import conjugate_gradient

from physics.multiscale import MultiScaleLinearPhysics, Pad
from models.blocks import (
    EquivMaxPool,
    AffineConv2d,
    ConvNextBlock2,
    NoiseEmbedding,
    MPConv,
    TimestepEmbedding,
    conv,
    downsample_strideconv,
    upsample_convtranspose,
)
from models.heads import (
    Heads,
    Tails,
    InHead,
    OutTail,
    ConvChannels,
    SNRModule,
    EquivConvModule,
    EquivHeads,
)

cuda = torch.cuda.is_available()
Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor


### --------------- MODEL ---------------


class BaseEncBlock(nn.Module):
    """Stack of ``nb`` :class:`ResBlock` modules applied one after the other.

    Every constructor argument except ``nb`` is forwarded unchanged to each
    :class:`ResBlock`; see that class for their meaning.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        bias=False,
        mode="CRC",
        nb=2,
        embedding=False,
        emb_channels=None,
        emb_physics=False,
        img_channels=None,
        decode_upscale=None,
        config='A',
        N=4,
        c_mult=1,
        depth_encoding=1,
        relu_in_encoding=False,
        skip_in_encoding=True,
    ):
        super(BaseEncBlock, self).__init__()
        self.config = config
        self.enc = nn.ModuleList(
            [
                ResBlock(
                    in_channels,
                    out_channels,
                    bias=bias,
                    mode=mode,
                    embedding=embedding,
                    emb_channels=emb_channels,
                    emb_physics=emb_physics,
                    img_channels=img_channels,
                    decode_upscale=decode_upscale,
                    config=config,
                    N=N,
                    c_mult=c_mult,
                    depth_encoding=depth_encoding,
                    relu_in_encoding=relu_in_encoding,
                    skip_in_encoding=skip_in_encoding,
                )
                for _ in range(nb)
            ]
        )

    def forward(self, x, emb_sigma=None, physics=None, t=None, y=None, emb_in=None, img_channels=None, scale=0):
        """Apply the residual blocks sequentially to ``x``.

        The conditioning arguments are passed to every block. ``emb_in`` is
        accepted for interface compatibility but is not used.
        """
        for block in self.enc:
            x = block(x, emb_sigma=emb_sigma, physics=physics, t=t, y=y, img_channels=img_channels, scale=scale)
        return x


class NextEncBlock(nn.Module):
    """Stack of ``nb`` ``ConvNextBlock2`` modules applied one after the other."""

    def __init__(
        self, in_channels, out_channels, bias=False, mode="", mult_fact=4, nb=2
    ):
        super(NextEncBlock, self).__init__()
        self.enc = nn.ModuleList(
            [
                ConvNextBlock2(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    bias=bias,
                    mode=mode,
                    mult_fact=mult_fact,
                )
                for _ in range(nb)
            ]
        )

    def forward(self, x, emb_sigma=None, **kwargs):
        """Apply the ConvNeXt blocks sequentially to ``x``.

        ``UNeXt.forward_unet`` calls every encoder/decoder block with the same
        keyword arguments (``physics``, ``t``, ``y``, ``img_channels``,
        ``scale``). These blocks use only ``emb_sigma``, so the remaining
        keywords are accepted and ignored.
        """
        for block in self.enc:
            x = block(x, emb_sigma)
        return x


class UNeXt(nn.Module):
    r"""
    DRUNet-style U-Net denoiser with optional measurement conditioning.

    The architecture follows the DRUNet denoiser of the paper "Learning deep CNN
    denoiser prior for image restoration" (K. Zhang et al.). It is a 4-level
    U-Net with convolutional blocks in the encoder and decoder and additive skip
    connections.

    With ``cond_type="base"``, the noise level ``sigma`` and the parameter
    ``gamma`` are appended to the input as two constant channels (see
    :meth:`base_conditioning`). When measurements ``y`` are passed to
    :meth:`forward`, the input is first realigned with ``physics.prox_l2``
    (see :meth:`realign_input`).

    :param int, list in_channels: number of input channels, or a list of
        supported channel counts (``forward`` rejects inputs whose channel count
        is not in the list).
    :param int, list out_channels: not used by the constructor; the tail is
        built from ``in_channels``.
    :param list nc: number of feature channels at each of the 4 levels.
    :param int nb: number of blocks per level.
    :param str conv_type: ``"base"`` (:class:`ResBlock` stacks) or ``"next"``
        (``ConvNextBlock2`` stacks).
    :param str pool_type: ``"base"`` (strided conv / transposed conv) or
        ``"next_max"`` (``EquivMaxPool``). The default ``"next"`` is rejected by
        the constructor.
    :param str cond_type: ``"base"`` (concatenated sigma/gamma maps) or
        ``"edm"`` (noise embedding passed to the blocks).
    :param device: device the parameters are moved to (if not ``None``).
    :param bool bias: forwarded to the ConvNeXt blocks.
    :param str mode: forwarded to the ConvNeXt blocks.
    :param str init_type: forwarded to :func:`weights_init_unext` when
        ``conv_type != "base"``.
    :param float gain_init_conv: forwarded to :func:`weights_init_unext` when
        ``conv_type != "base"``.
    :param float gain_init_linear: forwarded to :func:`weights_init_unext` when
        ``conv_type != "base"``.
    :param int mult_fact: forwarded to the ConvNeXt blocks.
    :param str antialias: forwarded to ``EquivMaxPool``.
    :param bool emb_physics: if ``True``, each :class:`ResBlock` contains a
        :class:`MeasCondBlock`, and ``forward`` wraps ``physics`` in
        ``MultiScaleLinearPhysics``.
    :param str config: measurement-conditioning variant (see
        :class:`MeasCondBlock`). With ``'D'``, most parameters are frozen.
    :param str, None pretrained_pth: path to a DRUNet checkpoint, ``'jz'`` for
        a hard-coded cluster path, or ``None`` for random initialization.
    :param int N: forwarded to :class:`MeasCondBlock`.
    :param int c_mult: forwarded to :class:`MeasCondBlock`.
    :param int depth_encoding: forwarded to :class:`MeasCondBlock`.
    :param bool relu_in_encoding: forwarded to :class:`MeasCondBlock`.
    :param bool skip_in_encoding: forwarded to :class:`MeasCondBlock`.

    ``residual`` is stored but not used here. ``act_mode``,
    ``layer_scale_init_value``, ``drop_prob`` and ``replk`` are accepted but
    not used.
    """

    def __init__(
        self,
        in_channels=[1, 2, 3],
        out_channels=[1, 2, 3],
        nc=[64, 128, 256, 512],
        nb=4,  # 4 in DRUNet but out of memory
        conv_type="next",  # should be 'base' or 'next'
        pool_type="next",  # accepted values: 'base' or 'next_max'
        cond_type="base",  # conditioning, should be 'base' or 'edm'
        device=None,
        bias=False,
        mode="",
        residual=False,
        act_mode="R",
        layer_scale_init_value=1e-6,
        init_type="ortho",
        gain_init_conv=1.0,
        gain_init_linear=1.0,
        drop_prob=0.0,
        replk=False,
        mult_fact=4,
        antialias="gaussian",
        emb_physics=False,
        config='A',
        pretrained_pth=None,
        N=4,
        c_mult=1,
        depth_encoding=1,
        relu_in_encoding=False,
        skip_in_encoding=True,
    ):
        super(UNeXt, self).__init__()
        self.residual = residual
        self.conv_type = conv_type
        self.pool_type = pool_type
        self.emb_physics = emb_physics
        self.config = config
        self.in_channels = in_channels
        # Learnable scale applied to gamma in realign_input.
        self.fact_realign = torch.nn.Parameter(torch.tensor([1.0], device=device))
        self.separate_head = isinstance(in_channels, list)

        assert cond_type in ["base", "edm"], "cond_type should be 'base' or 'edm'"
        self.cond_type = cond_type

        # With "base" conditioning, base_conditioning() appends the sigma and
        # gamma maps, so the head must accept extra input channels (except for
        # config 'E').
        if self.cond_type == "base" and self.config != 'E':
            if isinstance(in_channels, list):
                in_channels_first = [c + 2 for c in in_channels]
            else:  # old head
                # NOTE(review): base_conditioning always appends 2 channels, so
                # +1 here looks inconsistent — confirm against InHead.
                in_channels_first = in_channels + 1
        else:
            in_channels_first = in_channels

        self.noise_embedding = NoiseEmbedding(
            num_channels=in_channels, emb_channels=max(nc), device=device
        )
        self.timestep_embedding = lambda x: x

        # check if in_channels is a list
        self.m_head = InHead(in_channels_first, nc[0])

        if conv_type == "next":
            next_block = functools.partial(
                NextEncBlock, bias=bias, mode=mode, mult_fact=mult_fact, nb=nb
            )
            self.m_down1 = next_block(nc[0], nc[0])
            self.m_down2 = next_block(nc[1], nc[1])
            self.m_down3 = next_block(nc[2], nc[2])
            self.m_body = next_block(nc[3], nc[3])
            self.m_up3 = next_block(nc[2], nc[2])
            self.m_up2 = next_block(nc[1], nc[1])
            self.m_up1 = next_block(nc[0], nc[0])
        elif conv_type == "base":
            base_block = functools.partial(
                BaseEncBlock,
                bias=False,
                mode="CRC",
                nb=nb,
                embedding=cond_type != "base",
                emb_channels=max(nc),
                emb_physics=emb_physics,
                img_channels=in_channels,
                config=config,
                N=N,
                c_mult=c_mult,
                depth_encoding=depth_encoding,
                relu_in_encoding=relu_in_encoding,
                skip_in_encoding=skip_in_encoding,
            )
            # decode_upscale grows with depth: 1, 2, 4 (encoder), 8 (body).
            self.m_down1 = base_block(nc[0], nc[0], decode_upscale=1)
            self.m_down2 = base_block(nc[1], nc[1], decode_upscale=2)
            self.m_down3 = base_block(nc[2], nc[2], decode_upscale=4)
            self.m_body = base_block(nc[3], nc[3], decode_upscale=8)
            self.m_up3 = base_block(nc[2], nc[2], decode_upscale=4)
            self.m_up2 = base_block(nc[1], nc[1], decode_upscale=2)
            self.m_up1 = base_block(nc[0], nc[0], decode_upscale=1)
        else:
            raise NotImplementedError("conv_type should be 'base' or 'next'")

        if pool_type == "next_max":
            # EquivMaxPool both downsamples and, via .upscale(), upsamples.
            self.pool1 = EquivMaxPool(antialias=antialias, in_channels=nc[0], out_channels=nc[1], device=device)
            self.pool2 = EquivMaxPool(antialias=antialias, in_channels=nc[1], out_channels=nc[2], device=device)
            self.pool3 = EquivMaxPool(antialias=antialias, in_channels=nc[2], out_channels=nc[3], device=device)
        elif pool_type == "base":
            self.pool1 = downsample_strideconv(nc[0], nc[1], bias=False, mode="2")
            self.pool2 = downsample_strideconv(nc[1], nc[2], bias=False, mode="2")
            self.pool3 = downsample_strideconv(nc[2], nc[3], bias=False, mode="2")
            self.up3 = upsample_convtranspose(nc[3], nc[2], bias=False, mode="2")
            self.up2 = upsample_convtranspose(nc[2], nc[1], bias=False, mode="2")
            self.up1 = upsample_convtranspose(nc[1], nc[0], bias=False, mode="2")
        else:
            raise NotImplementedError("pool_type should be 'base' or 'next_max'")

        self.m_tail = OutTail(nc[0], in_channels)

        if conv_type == "base":
            init_func = functools.partial(
                weights_init_unext, init_type="ortho", gain_conv=0.2
            )
        else:
            init_func = functools.partial(
                weights_init_unext,
                init_type=init_type,
                gain_conv=gain_init_conv,
                gain_linear=gain_init_linear,
            )
        self.apply(init_func)

        if pretrained_pth == 'jz':
            pth = '/lustre/fswork/projects/rech/nyd/commun/mterris/base_checkpoints/drunet_deepinv_color_finetune_22k.pth'
            self.load_drunet_weights(pth)
        elif pretrained_pth is not None:
            self.load_drunet_weights(pretrained_pth)

        if self.config == 'D':
            # Freeze every parameter whose name does not contain one of these tags.
            trainable_tags = ('PhysicsBlock', 'gain', 'fact_realign', 'm_head', 'm_tail')
            for name, param in self.named_parameters():
                if not any(tag in name for tag in trainable_tags):
                    param.requires_grad = False

        if device is not None:
            self.to(device)

    def load_drunet_weights(self, ckpt_pth):
        """Load DRUNet weights from ``ckpt_pth`` into this model (non-strict).

        Regular keys are renamed with :func:`old2new`. Keys containing "head"
        or "tail" are adapted with :func:`update_keyvals_headtail`. When
        ``in_channels`` is a list, this is done once per supported channel
        count. Matched and unmatched keys are printed.

        :param str ckpt_pth: path to a checkpoint readable by ``torch.load``.
        """
        # NOTE: torch.load unpickles the file; only load trusted checkpoints.
        state_dict = torch.load(ckpt_pth, map_location=lambda storage, loc: storage)
        new_state_dict = {}
        matched_keys = []  # (old_key, new_key) pairs successfully renamed
        unmatched_keys = []  # keys old2new could not map
        excluded_keys = []  # head/tail keys, adapted separately below

        exclude_patterns = ["head", "tail"]

        for old_key, value in state_dict.items():
            if any(pattern in old_key for pattern in exclude_patterns):
                excluded_keys.append(old_key)
                continue
            new_key = old2new(old_key)
            if new_key is None:
                unmatched_keys.append(old_key)
            else:
                matched_keys.append((old_key, new_key))
                new_state_dict[new_key] = value

        # TODO: clean this
        multi_head = isinstance(self.in_channels, list)
        # The model's own weights (used as shape/template) do not change in this loop.
        own_state = self.state_dict() if multi_head else None
        for excluded_key in excluded_keys:
            if multi_head:
                for i in range(len(self.in_channels)):
                    # A "tail" key maps to the i-th tail conv, anything else to the i-th head conv.
                    if 'tail' in excluded_key:
                        new_key = f"m_tail.conv{i}.weight"
                    else:
                        new_key = f"m_head.conv{i}.weight"
                    new_kv = update_keyvals_headtail(
                        excluded_key,
                        state_dict[excluded_key],
                        init_value=own_state[new_key],
                        new_key_name=new_key,
                        conditioning='base',
                    )
                    new_state_dict.update(new_kv)
            else:
                new_kv = update_keyvals_headtail(excluded_key, state_dict[excluded_key])
                new_state_dict.update(new_kv)

        print("Matched keys:")
        for old_key, new_key in matched_keys:
            print(f"{old_key} -> {new_key}")

        self.load_state_dict(new_state_dict, strict=False)

        print("\nUnmatched keys:")
        for unmatched_key in unmatched_keys:
            print(unmatched_key)

        print("Weights loaded from ", ckpt_pth)

    def constant2map(self, value, x):
        """Broadcast ``value`` to a ``(B, 1, H, W)`` map matching ``x``'s batch and spatial size.

        A non-scalar tensor must hold one value per batch element. A 0-d tensor
        or a Python number is used for the whole batch.
        """
        batch, height, width = x.size(0), x.size(2), x.size(3)
        if isinstance(value, torch.Tensor):
            if value.ndim > 0:
                return value.view(batch, 1, 1, 1).expand(-1, 1, height, width)
            return torch.ones((batch, 1, height, width), device=x.device) * value[None, None, None, None].to(x.device)
        return torch.ones((batch, 1, height, width), device=x.device) * value

    def base_conditioning(self, x, sigma, gamma):
        """Concatenate constant ``sigma`` and ``gamma`` maps to ``x`` along channels."""
        noise_level_map = self.constant2map(sigma, x)
        gamma_map = self.constant2map(gamma, x)
        return torch.cat((x, noise_level_map, gamma_map), 1)

    def realign_input(self, x, physics, y):
        """Replace ``x`` by ``physics.prox_l2(x, y, gamma)`` with an SNR-based ``gamma``.

        The scale ``factor`` and noise ``sigma`` are looked up on ``physics``,
        ``physics.base`` or ``physics.base.base``. ``gamma`` is derived per
        sample from ``mean|y| / sigma`` and is scaled by ``self.fact_realign``.
        """
        if hasattr(physics, "factor"):
            f = physics.factor
        elif hasattr(physics, "base") and hasattr(physics.base, "factor"):
            f = physics.base.factor
        elif hasattr(physics, "base") and hasattr(physics.base, "base") and hasattr(physics.base.base, "factor"):
            f = physics.base.base.factor
        else:
            f = 1.0

        # Later (deeper) matches override earlier ones.
        sigma = 1e-6  # default value
        if hasattr(getattr(physics, "noise_model", None), 'sigma'):
            sigma = physics.noise_model.sigma
        if hasattr(physics, 'base') and hasattr(physics.base, 'noise_model') and hasattr(physics.base.noise_model, 'sigma'):
            sigma = physics.base.noise_model.sigma
        if hasattr(physics, 'base') and hasattr(physics.base, 'base') and hasattr(physics.base.base, 'noise_model') and hasattr(physics.base.base.noise_model, 'sigma'):
            sigma = physics.base.base.noise_model.sigma

        y0 = y[0] if isinstance(y, TensorList) else y
        num = y0.reshape(y0.shape[0], -1).abs().mean(1)

        snr = num / (sigma + 1e-4)  # SNR equivariant
        gamma = 1 / (1e-4 + 1 / (snr * f ** 2))  # TODO: check square-root / mean / check if we need to add a factor in front ?
        # Reshape (B,) to (B, 1, ..., 1) to broadcast against x.
        gamma = gamma[(...,) + (None,) * (x.dim() - 1)]
        return physics.prox_l2(x, y, gamma=gamma * self.fact_realign)

    def forward_unet(self, x0, sigma=None, gamma=None, physics=None, t=None, y=None, img_channels=None):
        """Run the U-Net (head, 3 encoder levels, body, 3 decoder levels, tail).

        With ``cond_type="base"``, ``sigma`` and ``gamma`` must be given: they
        are concatenated to ``x0`` as constant maps. Otherwise ``sigma`` goes
        through ``self.noise_embedding``. ``physics``, ``y`` and
        ``img_channels`` are forwarded to every block, with the block's scale
        index (0 at full resolution, 3 at the body).
        """
        if self.cond_type == "base":
            x0 = self.base_conditioning(x0, sigma, gamma)
            emb_sigma = None
        else:
            # This only if the embedding is the non-basic one from drunet
            emb_sigma = self.noise_embedding(sigma)

        emb_timestep = self.timestep_embedding(t)
        cond = dict(emb_sigma=emb_sigma, physics=physics, t=emb_timestep, y=y, img_channels=img_channels)
        equiv_pool = self.pool_type in ("next", "next_max")

        x1 = self.m_head(x0)

        # ---- encoder ----
        if self.config == 'G':
            x1_, emb1_ = self.m_down1(x1, **cond)
        else:
            x1_ = self.m_down1(x1, scale=0, **cond)
        x2 = self.pool1(x1_)

        if self.config == 'G':
            x3_, emb3_ = self.m_down2(x2, **cond)
        else:
            x3_ = self.m_down2(x2, scale=1, **cond)
        x3 = self.pool2(x3_)

        if self.config == 'G':
            x4_, emb4_ = self.m_down3(x3, **cond)
        else:
            x4_ = self.m_down3(x3, scale=2, **cond)
        x4 = self.pool3(x4_)
        # issue: https://github.com/matthieutrs/ram_project/issues/1
        # solution 1: using .contiguous() below
        # solution 2: using a print statement that magically solves the issue

        # ---- body ----
        if self.config == 'G':
            x, _ = self.m_body(x4, **cond)
        else:
            x = self.m_body(x4, scale=3, **cond)

        # ---- decoder (additive skip connections) ----
        x = self.pool3.upscale(x + x4) if equiv_pool else self.up3(x + x4)
        if self.config == 'G':
            x, _ = self.m_up3(x, emb_in=emb4_, **cond)
        else:
            x = self.m_up3(x, scale=2, **cond)

        x = self.pool2.upscale(x + x3) if equiv_pool else self.up2(x + x3)
        if self.config == 'G':
            x, _ = self.m_up2(x, emb_in=emb3_, **cond)
        else:
            x = self.m_up2(x, scale=1, **cond)

        x = self.pool1.upscale(x + x2) if equiv_pool else self.up1(x + x2)
        if self.config == 'G':
            x, _ = self.m_up1(x, emb_in=emb1_, **cond)
        else:
            x = self.m_up1(x, scale=0, **cond)

        if self.separate_head:
            return self.m_tail(x + x1, img_channels)
        return self.m_tail(x + x1)

    def forward(self, x, sigma=None, gamma=None, physics=None, t=None, y=None):
        r"""
        Run the denoiser on image with noise level :math:`\sigma`.

        :param torch.Tensor x: noisy image.
        :param float, torch.Tensor sigma: noise level. If ``sigma`` is a float,
            it is used for all images in the batch. If ``sigma`` is a tensor,
            it must be of shape ``(batch_size,)``.
        :param float, torch.Tensor gamma: second conditioning value (same
            convention as ``sigma``); required when ``cond_type="base"``.
        :param physics: forward operator; wrapped in ``MultiScaleLinearPhysics``
            when ``emb_physics`` is set.
        :param t: forwarded to the blocks (identity timestep embedding).
        :param y: measurements. If given, ``x`` is realigned first
            (see :meth:`realign_input`).
        :raises ValueError: if ``in_channels`` is a list and ``x`` has an
            unsupported number of channels.
        """
        img_channels = x.shape[1]
        if self.emb_physics:
            physics = MultiScaleLinearPhysics(physics, x.shape[-3:], device=x.device)

        if self.separate_head and img_channels not in self.in_channels:
            raise ValueError(f"Input image has {img_channels} channels, but the network only have heads for {self.in_channels} channels.")

        if y is not None:
            x = self.realign_input(x, physics, y)

        return self.forward_unet(x, sigma=sigma, gamma=gamma, physics=physics, t=t, y=y, img_channels=img_channels)


def krylov_embeddings_old(y, p, factor, v=None, N=4, feat_size=1, x_init=None, img_channels=3):
    """Legacy Krylov embedding (see :func:`krylov_embeddings`).

    Differences from :func:`krylov_embeddings`: ``v`` is also subtracted from
    the first term, and ``x_init`` is sliced to its first ``img_channels``
    channels. With ``feat_size > 1``, ``v`` holds one ``C``-channel slice per
    term (``N`` slices).
    """
    if x_init is None:
        x = p.A_adjoint(y)
    else:
        x = x_init[:, :img_channels, ...]

    norm = factor ** 2

    def A(u):
        return p.A_adjoint(p.A(u)) * norm

    if feat_size > 1:
        _, C, _, _ = x.shape
        if v is None:
            # One slice for the first term plus one per iteration: N slices.
            v = torch.zeros_like(x).repeat(1, N, 1, 1)
        out = x - v[:, :C, ...]
        for i in range(N - 1):
            x = A(x) - v[:, (i + 1) * C:(i + 2) * C, ...]
            out = torch.cat([out, x], dim=1)
    else:
        if v is None:
            v = torch.zeros_like(x)
        out = x - v
        for _ in range(N - 1):
            x = A(x) - v
            out = torch.cat([out, x], dim=1)
    return out


def krylov_embeddings(y, p, factor, v=None, N=4, x_init=None, img_channels=3):
    """
    Krylov-subspace embedding of the measurements.

    Builds ``x_0 = x`` and ``x_{k+1} = factor**2 * A^T A x_k - v`` for
    ``k = 0, ..., N-2``, then concatenates the ``N`` terms along dimension 1.

    Args:
        y (torch.Tensor): The measurements.
        p: An object with ``A`` and ``A_adjoint`` methods (linear operator).
        factor (float): Scaling factor; ``A^T A`` is multiplied by ``factor**2``.
        v (torch.Tensor, optional): Tensor subtracted from every iterate after
            the first. Defaults to None (zeros).
        N (int, optional): Number of terms in the output. Defaults to 4.
        x_init (torch.Tensor, optional): Starting point ``x_0``. Defaults to
            ``p.A_adjoint(y)``.
        img_channels (int, optional): Not used; kept for backward compatibility.

    Returns:
        torch.Tensor: The ``N`` terms concatenated along dim 1 (``N`` times the
        channel count of ``x_0``).
    """
    x = p.A_adjoint(y) if x_init is None else x_init.clone()

    norm = factor ** 2

    def AtA(u):
        return p.A_adjoint(p.A(u)) * norm

    if v is None:
        v = torch.zeros_like(x)

    terms = [x.clone()]
    x_k = x
    for _ in range(N - 1):
        x_k = AtA(x_k) - v
        terms.append(x_k)
    # Concatenate once instead of growing the output inside the loop.
    return torch.cat(terms, dim=1)


def grad_embeddings(y, p, factor, v=None, N=4, feat_size=1):
    """Gradient-style embedding ``factor**2 A^T A v - A^T y`` (first term: ``v - A^T y``).

    With ``feat_size > 1``, ``v`` holds one ``C``-channel slice per term
    (``N`` slices). Terms are concatenated along dimension 1.
    """
    Aty = p.A_adjoint(y)
    norm = factor ** 2

    def A(u):
        return p.A_adjoint(p.A(u)) * norm

    if feat_size > 1:
        _, C, _, _ = Aty.shape
        if v is None:
            # One slice for the first term plus one per iteration: N slices.
            v = torch.zeros_like(Aty).repeat(1, N, 1, 1)
        out = v[:, :C, ...] - Aty
        for i in range(N - 1):
            x = A(v[:, (i + 1) * C:(i + 2) * C, ...]) - Aty
            out = torch.cat([out, x], dim=1)
    else:
        if v is None:
            v = torch.zeros_like(Aty)
        out = v - Aty
        for _ in range(N - 1):
            x = A(v) - Aty
            out = torch.cat([out, x], dim=1)
    return out


def prox_embeddings(y, p, factor, v=None, N=4):
    """Embedding from ``N-1`` regularized least-squares solves plus ``A^T y``.

    Runs 3 conjugate-gradient iterations on all ``N-1`` channel groups at once,
    with weights ``gamma`` log-spaced in ``[1e-4, 1e-1]``. Then appends
    ``A^T y`` along dimension 1.
    """
    x = p.A_adjoint(y)
    _, C, _, _ = x.shape
    if v is None:
        v = torch.zeros_like(x)
    v = v.repeat(1, N - 1, 1, 1)
    gamma = torch.logspace(-4, -1, N - 1, device=x.device).repeat_interleave(C).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
    norm = factor ** 2

    def A_sub(u):
        return torch.cat([p.A_adjoint(p.A(u[:, i * C:(i + 1) * C, ...])) * norm for i in range(N - 1)], dim=1)

    # NOTE(review): with a non-zero v this operator is affine, not linear;
    # conjugate gradient assumes linearity — v probably belongs on the RHS.
    def A(u):
        return A_sub(u) + (u - v) * gamma

    u_hat = conjugate_gradient(A, x.repeat(1, N - 1, 1, 1), max_iter=3, tol=1e-3)
    return torch.cat([u_hat, x], dim=1)


# --------------------------------------------
# Res Block: x + conv(relu(conv(x)))
# --------------------------------------------
class MeasCondBlock(nn.Module):
    """Measurement-conditioning branch used inside :class:`ResBlock`.

    Builds a feature-space embedding from the measurements ``y``, the operator
    ``physics`` and (for some configs) the current features ``x``:

    * ``'A'``: Krylov embedding of ``y``.
    * ``'B'``: Krylov embedding of ``y``, shifted by decoded features.
    * ``'C'``: Krylov embeddings started from ``A^T y`` and from each slice of
      the decoded features, followed by ReLU.
    * ``'D'``: same as ``'C'`` without the final ReLU; adds gain parameters.
    * ``'E'``/``'F'``: prox/gradient variants. NOTE(review): they use
      attributes (``decoding_conv``, ``fact_prox*``, ``relu_decoding``,
      ``gain_*``) that ``__init__`` does not create for these configs.
    """

    def __init__(
        self,
        out_channels=64,
        img_channels=None,
        decode_upscale=None,
        config='A',
        N=4,
        depth_encoding=1,
        relu_in_encoding=False,
        skip_in_encoding=True,
        c_mult=1,
    ):
        super(MeasCondBlock, self).__init__()

        self.separate_head = isinstance(img_channels, list)
        self.config = config

        assert img_channels is not None, "decode_dimensions should be provided"
        assert decode_upscale is not None, "decode_upscale should be provided"

        if self.config in ('A', 'B', 'C', 'D'):
            self.N = N
            self.c_mult = c_mult
            self.relu_encoding = nn.ReLU(inplace=False)
            if self.config != 'A':
                # Presumably maps features back to image channels (used as Krylov start / shift).
                self.decoding_conv = Tails(out_channels, img_channels, depth=1, scale=1, bias=False, c_mult=self.c_mult)
            if self.config in ('C', 'D'):
                # C/D encode Krylov terms from y and from each of the c_mult decoded slices.
                head_mult = dict(c_mult=self.c_mult * N, c_add=N)
            else:
                head_mult = dict(c_mult=self.c_mult)
            self.encoding_conv = Heads(
                img_channels,
                out_channels,
                depth=depth_encoding,
                scale=1,
                bias=False,
                relu_in=relu_in_encoding,
                skip_in=skip_in_encoding,
                **head_mult,
            )
        if self.config == 'D':
            self.gain = torch.nn.Parameter(torch.tensor([1.0]), requires_grad=True)
            self.gain_gradx = torch.nn.Parameter(torch.tensor([1e-2]), requires_grad=True)
            self.gain_grady = torch.nn.Parameter(torch.tensor([1e-2]), requires_grad=True)
            self.gain_pinvx = torch.nn.Parameter(torch.tensor([1e-2]), requires_grad=True)
            self.gain_pinvy = torch.nn.Parameter(torch.tensor([1e-2]), requires_grad=True)

    def forward(self, x, y, physics, t, emb_in=None, img_channels=None, scale=1):
        """Dispatch to the ``measurement_conditioning_config_*`` method for ``self.config``.

        ``t`` and ``emb_in`` are accepted but not used.

        :raises NotImplementedError: for an unknown config.
        """
        if self.config == 'A':
            return self.measurement_conditioning_config_A(x, y, physics, img_channels=img_channels, scale=scale)
        elif self.config == 'F':
            return self.measurement_conditioning_config_F(x, y, physics, img_channels=img_channels, scale=scale)
        elif self.config == 'B':
            return self.measurement_conditioning_config_B(x, y, physics, img_channels=img_channels, scale=scale)
        elif self.config == 'C':
            return self.measurement_conditioning_config_C(x, y, physics, img_channels=img_channels, scale=scale)
        elif self.config == 'D':
            return self.measurement_conditioning_config_D(x, y, physics, img_channels=img_channels, scale=scale)
        elif self.config == 'E':
            return self.measurement_conditioning_config_E(x, y, physics, img_channels=img_channels, scale=scale)
        else:
            raise NotImplementedError('Config not implemented')

    def measurement_conditioning_config_A(self, x, y, physics, img_channels, scale=0):
        """Encode the Krylov embedding of ``y`` at ``scale`` (``x`` is not used)."""
        physics.set_scale(scale)
        factor = 2 ** scale
        meas = krylov_embeddings(y, physics, factor, N=self.N, img_channels=img_channels)
        return self.relu_encoding(self.encoding_conv(meas))

    def measurement_conditioning_config_B(self, x, y, physics, img_channels, scale=0):
        """Encode the Krylov embedding of ``y`` with the decoded features as shift ``v``."""
        physics.set_scale(scale)
        dec = self.decoding_conv(x, img_channels)
        factor = 2 ** scale
        meas = krylov_embeddings(y, physics, factor, v=dec, N=self.N, img_channels=img_channels)
        return self.relu_encoding(self.encoding_conv(meas))  # * sigma_emb

    def _krylov_pair_encoding(self, x, y, physics, img_channels, scale):
        """Shared body of configs C and D; returns the encoding before any ReLU."""
        physics.set_scale(scale)
        dec = self.decoding_conv(x, img_channels)
        factor = 2 ** scale
        meas_y = krylov_embeddings(y, physics, factor, N=self.N, img_channels=img_channels)
        # One Krylov embedding per img_channels-wide slice of the decoded tensor
        # (at least one slice is always used).
        meas_dec = [
            krylov_embeddings(
                y, physics, factor, N=self.N,
                x_init=dec[:, img_channels * c:img_channels * (c + 1)],
                img_channels=img_channels,
            )
            for c in range(max(1, self.c_mult))
        ]
        meas = torch.cat([meas_y, *meas_dec], dim=1)
        return self.encoding_conv(meas)

    def measurement_conditioning_config_C(self, x, y, physics, img_channels, scale=0):
        """Encode Krylov terms from ``A^T y`` and from the decoded features, then apply ReLU."""
        return self.relu_encoding(self._krylov_pair_encoding(x, y, physics, img_channels, scale))

    def measurement_conditioning_config_D(self, x, y, physics, img_channels, scale=0):
        """Same as config C but returns the encoding without the ReLU."""
        # NOTE(review): the original computed the ReLU here but returned the
        # pre-activation value; that behavior is kept — confirm it is intended.
        return self._krylov_pair_encoding(x, y, physics, img_channels, scale)

    def measurement_conditioning_config_F(self, x, y, physics, img_channels, scale=0):
        """Gradient / pseudo-inverse embedding of the decoded features.

        NOTE(review): uses ``decoding_conv``, ``relu_decoding`` and
        ``gain_*``, which ``__init__`` does not create for config 'F'.
        """
        physics.set_scale(scale)
        dec_large = self.decoding_conv(x, img_channels)  # go from shape = (B, C, H, W) to (B, 64, 64, 64) (independent of modality)
        dec = self.relu_decoding(dec_large)

        Adec = physics.A(dec)

        # TODO: check if we need to have L2 (depending on noise nature, can be automated)
        grad = physics.A_adjoint(self.gain_gradx ** 2 * Adec - self.gain_grady ** 2 * y)
        if 'tomography' in physics.__class__.__name__.lower():  # or 'pansharp' in physics.__class__.__name__.lower():
            pinv = physics.prox_l2(dec, self.gain_pinvx ** 2 * Adec - self.gain_pinvy ** 2 * y, gamma=1e9)
        else:
            # TODO: do we set this to gain_gradx ? To get 0 during training too?? Better for denoising I guess
            pinv = physics.A_dagger(self.gain_pinvx ** 2 * Adec - self.gain_pinvy ** 2 * y)

        # Mix grad and pinv
        emb = grad - pinv  # will be 0 in the case of denoising, but also inpainting
        im_emb = dec - physics.A_adjoint_A(dec)  # will be 0 in the case of denoising, but not inpainting
        # TODO: add gains here too
        grad_large = emb + im_emb

        emb_grad = self.encoding_conv(grad_large)
        return self.relu_encoding(emb_grad)

    def measurement_conditioning_config_E(self, x, y, physics, img_channels, scale=1):
        """Prox-based embedding with an SNR-derived ``gamma``.

        NOTE(review): uses ``decoding_conv``, ``fact_prox``,
        ``fact_prox_skip_1/2``, ``encoding_conv`` and ``relu_encoding``, which
        ``__init__`` does not create for config 'E'.
        """
        dec = self.decoding_conv(x, img_channels)  # go from shape = (B, C, H, W) to (B, 64, 64, 64) (independent of modality)
        physics.set_scale(scale)
        # TODO: check things are batched
        f = physics.factor if hasattr(physics, "factor") else 1.0
        err = physics.A_adjoint(physics.A(dec) - y)
        # Per-sample ratio of mean |dec| to mean |A^T(A dec - y)|.
        snr = dec.reshape(dec.shape[0], -1).abs().mean(dim=1) / (err.reshape(err.shape[0], -1).abs().mean(dim=1) + 1e-4)
        gamma = 1 / (1e-4 + 1 / (snr * f ** 2 + 1))  # TODO: check square-root / mean / check if we need to add a factor in front
        gamma_est = gamma[(...,) + (None,) * (dec.dim() - 1)]
        prox = physics.prox_l2(dec, y, gamma=gamma_est * self.fact_prox)
        emb = self.fact_prox_skip_1 * prox + self.fact_prox_skip_2 * dec

        emb_grad = self.encoding_conv(emb)
        return self.relu_encoding(emb_grad)


class ResBlock(nn.Module):
    """Residual block ``x + conv2(relu(conv1(x)))``, optionally plus a measurement term.

    With ``emb_physics=True``, ``gain * MeasCondBlock(relu(conv1(x)), y, ...)``
    is added to the output. With ``embedding=True``, a ``gain`` parameter and an
    ``emb_linear`` layer are created, but ``forward`` does not use
    ``emb_linear``. ``head=True`` builds an ``InHead`` input layer instead of
    the convolutions. NOTE(review): ``forward`` always uses ``conv1``/``conv2``,
    so it fails for ``head``/``tail`` blocks.
    """

    def __init__(
        self,
        in_channels=64,
        out_channels=64,
        kernel_size=3,
        stride=1,
        padding=1,
        bias=True,
        mode="CRC",
        negative_slope=0.2,
        embedding=False,
        emb_channels=None,
        emb_physics=False,
        img_channels=None,
        decode_upscale=None,
        config='A',
        head=False,
        tail=False,
        N=4,
        c_mult=1,
        depth_encoding=1,
        relu_in_encoding=False,
        skip_in_encoding=True,
    ):
        super(ResBlock, self).__init__()

        if not head and not tail:
            assert in_channels == out_channels, "Only support in_channels==out_channels."

        self.separate_head = isinstance(img_channels, list)
        self.config = config
        self.is_head = head
        self.is_tail = tail

        if self.is_head:
            self.head = InHead(img_channels, out_channels, input_layer=True)

        if not self.is_head and not self.is_tail:
            self.conv1 = conv(in_channels, out_channels, kernel_size, stride, padding, bias, "C", negative_slope)
            self.nl = nn.ReLU(inplace=True)
            self.conv2 = conv(out_channels, out_channels, kernel_size, stride, padding, bias, "C", negative_slope)

        if embedding:
            self.gain = torch.nn.Parameter(torch.tensor([1.0]), requires_grad=True)
            self.emb_linear = MPConv(emb_channels, out_channels, kernel=[])

        self.emb_physics = emb_physics
        if self.emb_physics:
            self.gain = torch.nn.Parameter(torch.tensor([1.0]), requires_grad=True)
            self.PhysicsBlock = MeasCondBlock(
                out_channels=out_channels,
                config=config,
                c_mult=c_mult,
                img_channels=img_channels,
                decode_upscale=decode_upscale,
                N=N,
                depth_encoding=depth_encoding,
                relu_in_encoding=relu_in_encoding,
                skip_in_encoding=skip_in_encoding,
            )

    def forward(self, x, emb_sigma=None, physics=None, t=None, y=None, emb_in=None, img_channels=None, scale=0):
        """Return ``x + conv2(relu(conv1(x)))`` plus ``gain * PhysicsBlock(...)`` when enabled.

        ``emb_sigma`` and ``emb_in`` are accepted but not used.
        """
        u = self.nl(self.conv1(x))
        u_2 = self.conv2(u)  # Should we sum this with below?
        if self.emb_physics:
            # TODO: add a factor (1+gain) to the emb_meas? that depends on the input snr
            emb_grad = self.PhysicsBlock(u, y, physics, t, img_channels=img_channels, scale=scale)
            u_1 = self.gain * emb_grad  # x - grad (sign does not matter)
        else:
            u_1 = 0
        return x + u_2 + u_1


def calculate_fan_in_and_fan_out(tensor, pytorch_style: bool = True):
    """
    Return ``(fan_in, fan_out)`` for a 2-D, 4-D or 5-D weight tensor.

    from https://github.com/megvii-research/basecls/blob/main/basecls/layers/wrapper.py#L77

    :raises ValueError: for tensors of any other rank.
    """
    if len(tensor.shape) not in (2, 4, 5):
        raise ValueError(
            "fan_in and fan_out can only be computed for tensor with 2/4/5 "
            "dimensions"
        )
    if len(tensor.shape) == 5:
        # `GOIKK` to `OIKK`
        tensor = tensor.reshape(-1, *tensor.shape[2:]) if pytorch_style else tensor[0]

    num_input_fmaps = tensor.shape[1]
    num_output_fmaps = tensor.shape[0]
    receptive_field_size = math.prod(tensor.shape[2:]) if len(tensor.shape) > 2 else 1
    fan_in = num_input_fmaps * receptive_field_size
    fan_out = num_output_fmaps * receptive_field_size
    return fan_in, fan_out


def weights_init_unext(m, gain_conv=1.0, gain_linear=1.0, init_type="ortho"):
    """Initialize the conv/linear layers found in ``m.modules()``.

    * Conv / ConvTranspose whose repr contains "skip": weights set to ones.
    * Other convs with kernel width < 4: orthogonal init, gain 0.2.
    * Other convs: normal init with std ``sqrt(2 / fan_out)``.
    * Linear layers whose repr does not contain "skip": normal init, std 0.01.

    ``gain_conv``, ``gain_linear`` and ``init_type`` are accepted but currently
    ignored.
    """
    if not hasattr(m, "modules"):
        return
    conv_types = (nn.Conv2d, nn.ConvTranspose2d)
    for submodule in m.modules():
        # Only leaf conv/linear layers are touched, so their (short) repr is
        # computed instead of the repr of every container module.
        if isinstance(submodule, conv_types):
            if 'skip' in str(submodule):
                nn.init.ones_(submodule.weight.data)
            elif submodule.weight.data.shape[-1] < 4:
                nn.init.orthogonal_(submodule.weight.data, gain=0.2)
            else:
                _, fan_out = calculate_fan_in_and_fan_out(submodule.weight)
                nn.init.normal_(submodule.weight, 0, math.sqrt(2 / fan_out))
        elif isinstance(submodule, nn.Linear):
            if 'skip' not in str(submodule):
                nn.init.normal_(submodule.weight.data, std=0.01)


def old2new(old_key):
    """
    Convert an old DRUNet state-dict key to the UNeXt naming.

    Patterns handled (``resK`` index 0 -> ``conv1``, any other -> ``conv2``):

    1. Encoder residual blocks:
       ``m_down3.2.res.0.weight -> m_down3.enc.2.conv1.weight``
    2. Decoder residual blocks (old index 0 is the upsampling layer, so the
       block index is shifted by one):
       ``m_up3.2.res.0.weight -> m_up3.enc.1.conv1.weight``
    3. Encoder downsampling layer (position 4, i.e. after 4 residual blocks):
       ``m_down3.4.weight -> pool3.weight``
    4. Decoder upsampling layer: ``m_up3.0.weight -> up3.weight``
    5. Body blocks: ``m_body.0.res.2.weight -> m_body.enc.0.conv2.weight``

    Args:
        old_key (str): The old key from the state dictionary.

    Returns:
        str or None: The new key if matched, otherwise None.
    """
    match_residual = re.search(r"(m_down\d+)\.(\d+)\.res\.(\d+)", old_key)
    if match_residual:
        prefix = match_residual.group(1)  # e.g., "m_down2"
        index = match_residual.group(2)  # e.g., "3"
        conv_index = int(match_residual.group(3))  # e.g., 0
        new_conv_index = 1 if conv_index == 0 else 2
        return f"{prefix}.enc.{index}.conv{new_conv_index}.weight"

    match_residual = re.search(r"(m_up\d+)\.(\d+)\.res\.(\d+)", old_key)
    if match_residual:
        prefix = match_residual.group(1)  # e.g., "m_up2"
        index = int(match_residual.group(2))  # e.g., 3
        conv_index = int(match_residual.group(3))  # e.g., 0
        new_conv_index = 1 if conv_index == 0 else 2
        return f"{prefix}.enc.{index - 1}.conv{new_conv_index}.weight"

    match_pool_downsample = re.search(r"m_down(\d+)\.4\.weight", old_key)
    if match_pool_downsample:
        index = match_pool_downsample.group(1)  # e.g., "1" or "2"
        return f"pool{index}.weight"

    match_upsample = re.search(r"m_up(\d+)\.0\.weight", old_key)
    if match_upsample:
        index = match_upsample.group(1)  # e.g., "1" or "2"
        return f"up{index}.weight"

    match_body = re.search(r"(m_body)\.(\d+)\.res\.(\d+)\.weight", old_key)
    if match_body:
        prefix = match_body.group(1)  # e.g., "m_body"
        index = match_body.group(2)  # e.g., "0"
        conv_index = int(match_body.group(3))  # e.g., 2
        new_conv_index = 1 if conv_index == 0 else 2
        return f"{prefix}.enc.{index}.conv{new_conv_index}.weight"

    # If no patterns match, return None
    return None


def update_keyvals_headtail(old_key, old_value, init_value=None, new_key_name='m_head.conv0.weight', conditioning='base'):
    """
    Converting old DRUNet keys to new UNExt style keys.
    KEYS do not change but weight need to be 0 padded.

    Args:
        old_key (str): The old key from the state dictionary.
    """
    if 'head' in old_key:
        if conditioning == 'base':
            c_in = init_value.shape[1]
            c_in_old = old_value.shape[1]
            # if c_in == c_in_old:
            #     new_value = old_value.detach()
            # elif c_in < c_in_old:
            #     new_value = torch.zeros_like(init_value.detach())
            #     new_value[:, -1:, ...] = old_value[:, -1:, ...]
            #     new_value[:, :c_in-1, ...] = old_value[:, :c_in-1, ...]
            # if c_in == c_in_old:
            #     new_value = old_value.detach()
            # elif c_in < c_in_old:
            new_value = torch.zeros_like(init_value.detach())
            new_value[:, -2:-1, ...] = old_value[:, -1:, ...]
            new_value[:, -1:, ...] = old_value[:, -1:, ...]
            new_value[:, :c_in-2, ...] = old_value[:, :c_in-2, ...]
            return {new_key_name: new_value}
        else:
            c_in = init_value.shape[1]
            c_in_old = old_value.shape[1]
            # if c_in == c_in_old - 1:
            #     new_value = old_value[:, :-1, ...].detach()
            # elif c_in < c_in_old - 1:
            #     new_value = torch.zeros_like(init_value.detach())
            #     new_value[:, -1:, ...] = old_value[:, -1:, ...]
            #     new_value[:, ...] = old_value[:, :c_in, ...]
new_value = torch.zeros_like(init_value.detach()) new_value[:, -1:-2, ...] = old_value[:, -1:, ...] new_value[:, -1:, ...] = old_value[:, -1:, ...] new_value[:, ...] = old_value[:, :c_in, ...] return {new_key_name: new_value} elif 'tail' in old_key: c_in = init_value.shape[0] c_in_old = old_value.shape[0] new_value = torch.zeros_like(init_value.detach()) if c_in == c_in_old: new_value = old_value.detach() elif c_in < c_in_old: new_value = torch.zeros_like(init_value.detach()) new_value[:, ...] = old_value[:c_in, ...] return {new_key_name: new_value} else: print(f"Key {old_key} does not contain 'head' or 'tail'.") # test the network if __name__ == "__main__": net = UNeXt() x = torch.randn(1, 3, 128, 128) y = net(x, 0.1) # print(y.shape) # print(y) # Case for diagonal physics # IDEA 1: kills signal in the image of A # im_emb = dec - physics.A_adjoint_A(dec) # will be 0 in the case of denoising, but not inpainting # TODO: add gains here too # IDEA 2: compute norm of signal in ker of A # normker = (dec - physics.A_adjoint_A(dec)).norm() / (dec.norm() + 1e-4) # im_emb = normker * physics.A_adjoint(self.gain_diag_x * physics.A(dec) - self.gain_diag_y * y) # will be 0 in the case of denoising, but not inpainting # TODO: add gains here too # IDEA 3: same as above but add the pinv as well # normker = (dec - physics.A_adjoint_A(dec)).norm() / (dec.norm() + 1e-4) # grad_term = physics.A_adjoint(self.gain_diag_x * physics.A(dec) - self.gain_diag_y * y) # # pinv_term = physics.A_dagger(self.gain_diagpinv_x * physics.A(dec) - self.gain_diagpinv_y * y) # if 'tomography' in physics.__class__.__name__.lower(): # or 'pansharp' in physics.__class__.__name__.lower(): # pinv_term = physics.prox_l2(dec, self.gain_diagpinv_x ** 2 * Adec - self.gain_diagpinv_y ** 2 * y, gamma=1e9) # else: # pinv_term = physics.A_dagger(self.gain_diagpinv_x ** 2 * Adec - self.gain_diagpinv_y ** 2 * y) # TODO: do we set this to gain_gradx ? To get 0 during training too?? 
Better for denoising I guess # im_emb = normker * (grad_term + pinv_term) # will be 0 in the case of denoising, but not inpainting # TODO: add gains here too # # Mix it # if hasattr(physics.noise_model, 'sigma'): # sigma = physics.noise_model.sigma # SNR ? x /= sigma ** 2 # snr = (y.abs().mean()) / (sigma + 1e-4) # SNR equivariant # TODO: add epsilon # snr = snr[(...,) + (None,) * (im_emb.dim() - 1)] # else: # snr = 1e4 # # grad_large = emb + self.gain_diag * (1 + self.gain_noise / snr) * im_emb