# denoising/models/unext_wip.py
# Last commit by Yonuts: 12a4d59 ("gradio demo")
# Code borrowed from Kai Zhang https://github.com/cszn/DPIR/tree/master/models
import re
import math
import functools
import deepinv as dinv
from deepinv.utils import plot, TensorList
import torch
from torch.func import vmap
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
from deepinv.optim.utils import conjugate_gradient
from physics.multiscale import MultiScaleLinearPhysics, Pad
from models.blocks import EquivMaxPool, AffineConv2d, ConvNextBlock2, NoiseEmbedding, MPConv, TimestepEmbedding, conv, downsample_strideconv, upsample_convtranspose
from models.heads import Heads, Tails, InHead, OutTail, ConvChannels, SNRModule, EquivConvModule, EquivHeads
# Select the legacy float tensor type that matches the available accelerator.
cuda = torch.cuda.is_available()
Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor
### --------------- MODEL ---------------
class BaseEncBlock(nn.Module):
    """Sequential stack of ``nb`` :class:`ResBlock` modules.

    Every constructor argument except ``nb`` is handed unchanged to each
    ``ResBlock``; see that class for their meaning.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        bias=False,
        mode="CRC",
        nb=2,
        embedding=False,
        emb_channels=None,
        emb_physics=False,
        img_channels=None,
        decode_upscale=None,
        config='A',
        N=4,
        c_mult=1,
        depth_encoding=1,
        relu_in_encoding=False,
        skip_in_encoding=True,
    ):
        super().__init__()
        self.config = config
        shared = dict(
            bias=bias,
            mode=mode,
            embedding=embedding,
            emb_channels=emb_channels,
            emb_physics=emb_physics,
            img_channels=img_channels,
            decode_upscale=decode_upscale,
            config=config,
            N=N,
            c_mult=c_mult,
            depth_encoding=depth_encoding,
            relu_in_encoding=relu_in_encoding,
            skip_in_encoding=skip_in_encoding,
        )
        blocks = [ResBlock(in_channels, out_channels, **shared) for _ in range(nb)]
        self.enc = nn.ModuleList(blocks)

    def forward(self, x, emb_sigma=None, physics=None, t=None, y=None, emb_in=None, img_channels=None, scale=0):
        """Feed ``x`` through every block in order.

        ``emb_in`` is accepted for interface compatibility but is not passed on
        to the blocks.
        """
        for block in self.enc:
            x = block(
                x,
                emb_sigma=emb_sigma,
                physics=physics,
                t=t,
                y=y,
                img_channels=img_channels,
                scale=scale,
            )
        return x
class NextEncBlock(nn.Module):
    """Sequential stack of ``nb`` :class:`ConvNextBlock2` modules.

    Each block is called as ``block(x, emb_sigma)``.
    """

    def __init__(
        self, in_channels, out_channels, bias=False, mode="", mult_fact=4, nb=2
    ):
        super().__init__()
        layers = [
            ConvNextBlock2(
                in_channels=in_channels,
                out_channels=out_channels,
                bias=bias,
                mode=mode,
                mult_fact=mult_fact,
            )
            for _ in range(nb)
        ]
        self.enc = nn.ModuleList(layers)

    def forward(self, x, emb_sigma=None):
        """Apply every block in order, passing ``emb_sigma`` positionally."""
        for layer in self.enc:
            x = layer(x, emb_sigma)
        return x
class UNeXt(nn.Module):
    r"""
    U-Net reconstruction network derived from DRUNet.

    The layout follows
    `Learning deep CNN denoiser prior for image restoration <https://arxiv.org/abs/1704.03264>`_.
    There are four resolution levels with ``nc[0..3]`` feature channels, and each level is
    processed by a stack of ``nb`` residual blocks. Levels are joined by down/up-sampling
    modules and by additive skip connections between matching levels.

    With ``cond_type='base'`` the noise level ``sigma`` and the realignment weight ``gamma``
    are concatenated to the input as two constant channels (see :meth:`base_conditioning`).
    When ``in_channels`` is a list, the head/tail support several image channel counts, and
    the tail is called with the actual channel count of the input. When ``emb_physics`` is
    True, ``physics`` is wrapped in ``MultiScaleLinearPhysics`` and the residual blocks
    receive measurement-based conditioning. When ``y`` is passed to :meth:`forward`, the input
    is first realigned with a proximal step of the data-fidelity term (:meth:`realign_input`).

    :param int, list in_channels: number of image channels, or list of supported channel counts.
    :param out_channels: unused (the output tail is built from ``in_channels``); kept for API compatibility.
    :param list nc: number of feature channels at each of the four levels.
    :param int nb: number of residual blocks per level.
    :param str conv_type: ``'base'`` (``BaseEncBlock``) or ``'next'`` (``NextEncBlock``).
    :param str pool_type: ``'base'`` (strided conv down / transposed conv up), or ``'next'`` /
        ``'next_max'`` (``EquivMaxPool`` for both directions).
    :param str cond_type: ``'base'`` (concatenated sigma/gamma maps) or ``'edm'`` (learned noise embedding).
    :param device: device the model is moved to (``None`` leaves it where it is).
    :param bool bias: forwarded to ``NextEncBlock`` (``'base'`` blocks always use ``bias=False``).
    :param str mode: forwarded to ``NextEncBlock``.
    :param int mult_fact: forwarded to ``NextEncBlock``.
    :param str init_type: forwarded to :func:`weights_init_unext` when ``conv_type != 'base'``.
    :param float gain_init_conv: forwarded to :func:`weights_init_unext` when ``conv_type != 'base'``.
    :param float gain_init_linear: forwarded to :func:`weights_init_unext` when ``conv_type != 'base'``.
    :param str antialias: forwarded to ``EquivMaxPool``.
    :param bool emb_physics: enable measurement conditioning inside the residual blocks.
    :param str config: measurement-conditioning variant forwarded to the residual blocks. ``'D'``
        additionally freezes every parameter whose name does not contain ``PhysicsBlock``,
        ``gain``, ``fact_realign``, ``m_head`` or ``m_tail``.
    :param str, None pretrained_pth: path to DRUNet-style weights loaded by :meth:`load_drunet_weights`
        (``'jz'`` selects a hard-coded cluster checkpoint), or ``None`` for random initialisation.
    :param int N: forwarded to the residual blocks (``conv_type='base'`` only).
    :param int c_mult: forwarded to the residual blocks (``conv_type='base'`` only).
    :param int depth_encoding: forwarded to the residual blocks (``conv_type='base'`` only).
    :param bool relu_in_encoding: forwarded to the residual blocks (``conv_type='base'`` only).
    :param bool skip_in_encoding: forwarded to the residual blocks (``conv_type='base'`` only).
    :param residual: stored as ``self.residual`` but otherwise unused.
    :param act_mode: unused; kept for API compatibility.
    :param layer_scale_init_value: unused; kept for API compatibility.
    :param drop_prob: unused; kept for API compatibility.
    :param replk: unused; kept for API compatibility.
    """

    def __init__(
        self,
        in_channels=[1, 2, 3],
        out_channels=[1, 2, 3],
        nc=[64, 128, 256, 512],
        nb=4,  # 4 in DRUNet but out of memory
        conv_type="next",  # 'base' or 'next'
        pool_type="next",  # 'base', 'next' or 'next_max'
        cond_type="base",  # conditioning: 'base' or 'edm'
        device=None,
        bias=False,
        mode="",
        residual=False,
        act_mode="R",
        layer_scale_init_value=1e-6,
        init_type="ortho",
        gain_init_conv=1.0,
        gain_init_linear=1.0,
        drop_prob=0.0,
        replk=False,
        mult_fact=4,
        antialias="gaussian",
        emb_physics=False,
        config='A',
        pretrained_pth=None,
        N=4,
        c_mult=1,
        depth_encoding=1,
        relu_in_encoding=False,
        skip_in_encoding=True,
    ):
        super().__init__()
        self.residual = residual
        self.conv_type = conv_type
        self.pool_type = pool_type
        self.emb_physics = emb_physics
        self.config = config
        self.in_channels = in_channels
        # Learnable multiplier of the realignment weight gamma (see realign_input).
        self.fact_realign = torch.nn.Parameter(torch.tensor([1.0], device=device))
        self.separate_head = isinstance(in_channels, list)

        assert cond_type in ["base", "edm"], "cond_type should be 'base' or 'edm'"
        self.cond_type = cond_type
        if cond_type == "base" and config != 'E':
            # base_conditioning appends a sigma map and a gamma map to the input.
            if isinstance(in_channels, list):
                in_channels_first = [c + 2 for c in in_channels]
            else:  # old head
                # NOTE(review): base_conditioning appends 2 channels but this path reserves 1.
                in_channels_first = in_channels + 1
        else:
            in_channels_first = in_channels
        if cond_type != "base":
            self.noise_embedding = NoiseEmbedding(
                num_channels=in_channels, emb_channels=max(nc), device=device
            )
        self.timestep_embedding = lambda x: x

        self.m_head = InHead(in_channels_first, nc[0])

        # One factory per block family; each level only differs in width and decode_upscale.
        if conv_type == "next":
            def make_level(channels, decode_upscale):
                return NextEncBlock(
                    channels, channels, bias=bias, mode=mode, mult_fact=mult_fact, nb=nb
                )
        elif conv_type == "base":
            level_kwargs = dict(
                bias=False,
                mode="CRC",
                nb=nb,
                embedding=cond_type != "base",
                emb_channels=max(nc),
                emb_physics=emb_physics,
                img_channels=in_channels,
                config=config,
                N=N,
                c_mult=c_mult,
                depth_encoding=depth_encoding,
                relu_in_encoding=relu_in_encoding,
                skip_in_encoding=skip_in_encoding,
            )

            def make_level(channels, decode_upscale):
                return BaseEncBlock(
                    channels, channels, decode_upscale=decode_upscale, **level_kwargs
                )
        else:
            raise NotImplementedError("conv_type should be 'base' or 'next'")

        # Creation order (down1..3, body, up3..1) is kept stable: it determines the
        # order in which the seeded weight initialisation visits the modules.
        self.m_down1 = make_level(nc[0], 1)
        self.m_down2 = make_level(nc[1], 2)
        self.m_down3 = make_level(nc[2], 4)
        self.m_body = make_level(nc[3], 8)
        self.m_up3 = make_level(nc[2], 4)
        self.m_up2 = make_level(nc[1], 2)
        self.m_up1 = make_level(nc[0], 1)

        if pool_type in ("next", "next_max"):
            # forward_unet upsamples with ``poolk.upscale`` for both names, so both
            # need the EquivMaxPool modules (the previous code rejected 'next', the default).
            self.pool1 = EquivMaxPool(antialias=antialias, in_channels=nc[0], out_channels=nc[1], device=device)
            self.pool2 = EquivMaxPool(antialias=antialias, in_channels=nc[1], out_channels=nc[2], device=device)
            self.pool3 = EquivMaxPool(antialias=antialias, in_channels=nc[2], out_channels=nc[3], device=device)
        elif pool_type == "base":
            self.pool1 = downsample_strideconv(nc[0], nc[1], bias=False, mode="2")
            self.pool2 = downsample_strideconv(nc[1], nc[2], bias=False, mode="2")
            self.pool3 = downsample_strideconv(nc[2], nc[3], bias=False, mode="2")
            self.up3 = upsample_convtranspose(nc[3], nc[2], bias=False, mode="2")
            self.up2 = upsample_convtranspose(nc[2], nc[1], bias=False, mode="2")
            self.up1 = upsample_convtranspose(nc[1], nc[0], bias=False, mode="2")
        else:
            raise NotImplementedError("pool_type should be 'base', 'next' or 'next_max'")

        self.m_tail = OutTail(nc[0], in_channels)

        if conv_type == "base":
            init_func = functools.partial(weights_init_unext, init_type="ortho", gain_conv=0.2)
        else:
            init_func = functools.partial(
                weights_init_unext,
                init_type=init_type,
                gain_conv=gain_init_conv,
                gain_linear=gain_init_linear,
            )
        # NOTE: Module.apply calls init_func on every submodule and weights_init_unext itself
        # walks all descendants, so leaves are re-initialised once per ancestor; the final
        # root-level pass is the one that sticks. Kept as-is to preserve seeded initialisation.
        self.apply(init_func)

        if pretrained_pth == 'jz':
            pretrained_pth = '/lustre/fswork/projects/rech/nyd/commun/mterris/base_checkpoints/drunet_deepinv_color_finetune_22k.pth'
        if pretrained_pth is not None:
            self.load_drunet_weights(pretrained_pth)

        if config == 'D':
            # Only train the physics blocks, gains, fact_realign, head and tail.
            trainable = ('PhysicsBlock', 'gain', 'fact_realign', "m_head", "m_tail")
            for name, param in self.named_parameters():
                if not any(key in name for key in trainable):
                    param.requires_grad = False

        if device is not None:
            self.to(device)

    def load_drunet_weights(self, ckpt_pth):
        """Load a DRUNet-style checkpoint, renaming keys to this model's layout.

        Body keys are translated with :func:`old2new`. Keys containing ``"head"`` or
        ``"tail"`` are adapted with ``update_keyvals_headtail``, once per supported channel
        count when ``in_channels`` is a list. Loading is non-strict; matched and unmatched
        keys are printed.
        """
        state_dict = torch.load(ckpt_pth, map_location=lambda storage, loc: storage)
        new_state_dict = {}
        matched_keys = []  # (old_key, new_key) pairs successfully renamed
        unmatched_keys = []  # keys old2new could not translate
        excluded_keys = []  # head/tail keys handled separately below
        exclude_patterns = ["head", "tail"]

        for old_key, value in state_dict.items():
            if any(pattern in old_key for pattern in exclude_patterns):
                excluded_keys.append(old_key)
                continue
            new_key = old2new(old_key)
            if new_key is None:
                unmatched_keys.append(old_key)
            else:
                matched_keys.append((old_key, new_key))
                new_state_dict[new_key] = value

        # Built once: the model's own weights do not change until load_state_dict below.
        own_state = self.state_dict()
        for excluded_key in excluded_keys:
            if isinstance(self.in_channels, list):
                for i, _ in enumerate(self.in_channels):
                    # 'tail' wins if a key mentions both head and tail.
                    new_key = f"m_tail.conv{i}.weight" if 'tail' in excluded_key else f"m_head.conv{i}.weight"
                    new_kv = update_keyvals_headtail(
                        excluded_key,
                        state_dict[excluded_key],
                        init_value=own_state[new_key],
                        new_key_name=new_key,
                        conditioning='base',
                    )
                    new_state_dict.update(new_kv)
            else:
                new_state_dict.update(update_keyvals_headtail(excluded_key, state_dict[excluded_key]))

        print("Matched keys:")
        for old_key, new_key in matched_keys:
            print(f"{old_key} -> {new_key}")
        self.load_state_dict(new_state_dict, strict=False)
        print("\nUnmatched keys:")
        for unmatched_key in unmatched_keys:
            print(unmatched_key)
        print("Weights loaded from ", ckpt_pth)

    def constant2map(self, value, x):
        """Broadcast a per-image scalar to a ``(B, 1, H, W)`` map, with B, H, W taken from ``x``.

        ``value`` may be a Python number, a 0-d tensor, a tensor with one element per
        batch item, or a single-element tensor (shared by the whole batch). Tensor
        values are moved to ``x.device``.
        """
        batch, height, width = x.size(0), x.size(2), x.size(3)
        if isinstance(value, torch.Tensor):
            if value.ndim > 0:
                # reshape (not view) also accepts non-contiguous input; a size-1 leading
                # dimension is broadcast across the batch by expand.
                value_map = value.to(x.device).reshape(-1, 1, 1, 1)
                return value_map.expand(batch, 1, height, width)
            return torch.ones((batch, 1, height, width), device=x.device) * value[
                None, None, None, None
            ].to(x.device)
        return torch.ones((batch, 1, height, width), device=x.device) * value

    def base_conditioning(self, x, sigma, gamma):
        """Append constant ``sigma`` and ``gamma`` maps to ``x`` along the channel dimension."""
        noise_level_map = self.constant2map(sigma, x)
        gamma_map = self.constant2map(gamma, x)
        return torch.cat((x, noise_level_map, gamma_map), 1)

    @staticmethod
    def _physics_factor_and_sigma(physics):
        """Return ``(factor, sigma)`` found on ``physics``, ``physics.base`` or ``physics.base.base``.

        ``factor`` comes from the outermost level that has one (default ``1.0``).
        ``sigma`` comes from the innermost level whose ``noise_model`` has a ``sigma``
        (default ``1e-6``). Levels without a ``noise_model`` are skipped.
        """
        levels = [physics]
        while len(levels) < 3 and hasattr(levels[-1], "base"):
            levels.append(levels[-1].base)
        factor = next((level.factor for level in levels if hasattr(level, "factor")), 1.0)
        sigma = 1e-6  # default value
        for level in levels:
            noise_model = getattr(level, "noise_model", None)
            if hasattr(noise_model, 'sigma'):
                sigma = noise_model.sigma
        return factor, sigma

    def realign_input(self, x, physics, y):
        """Replace ``x`` by ``physics.prox_l2(x, y, gamma)`` with an SNR-derived ``gamma``.

        ``gamma`` grows with the mean absolute value of ``y`` divided by the noise level,
        scaled by ``factor ** 2`` and by the learnable ``fact_realign``.
        """
        f, sigma = self._physics_factor_and_sigma(physics)
        y_ref = y[0] if isinstance(y, TensorList) else y
        num = y_ref.reshape(y_ref.shape[0], -1).abs().mean(1)
        snr = num / (sigma + 1e-4)  # SNR equivariant
        gamma = 1 / (1e-4 + 1 / (snr * f ** 2))  # TODO: check square-root / mean / check if we need to add a factor in front ?
        gamma = gamma[(...,) + (None,) * (x.dim() - 1)]  # one gamma per batch item, broadcastable to x
        return physics.prox_l2(x, y, gamma=gamma * self.fact_realign)

    def _upsample(self, level, x):
        """Upsample ``x`` from ``level`` to ``level - 1`` with the module matching ``pool_type``."""
        if self.pool_type in ("next", "next_max"):
            return getattr(self, f"pool{level}").upscale(x)
        return getattr(self, f"up{level}")(x)

    def forward_unet(self, x0, sigma=None, gamma=None, physics=None, t=None, y=None, img_channels=None):
        """Run head -> 3 encoder levels -> body -> 3 decoder levels -> tail on ``x0``."""
        if self.cond_type == "base":
            # NOTE(review): also applied for config 'E', where __init__ reserves no extra head channels.
            x0 = self.base_conditioning(x0, sigma, gamma)
            emb_sigma = None
        else:
            emb_sigma = self.noise_embedding(sigma)  # only for the non-basic (edm) embedding
        emb_timestep = self.timestep_embedding(t)
        common = dict(emb_sigma=emb_sigma, physics=physics, t=emb_timestep, y=y, img_channels=img_channels)
        is_g = self.config == 'G'

        def run_level(block, inp, scale, emb_in=None):
            # Always return (features, embedding). Config 'G' blocks are expected to return
            # that pair themselves and receive no ``scale``; other blocks get ``scale``.
            if is_g:
                extra = {} if emb_in is None else {"emb_in": emb_in}
                return block(inp, **common, **extra)
            return block(inp, scale=scale, **common), None

        x1 = self.m_head(x0)
        x1_, emb1 = run_level(self.m_down1, x1, 0)
        x2 = self.pool1(x1_)
        x3_, emb3 = run_level(self.m_down2, x2, 1)
        x3 = self.pool2(x3_)
        x4_, emb4 = run_level(self.m_down3, x3, 2)
        x4 = self.pool3(x4_)
        # issue: https://github.com/matthieutrs/ram_project/issues/1
        # (a reported workaround is calling .contiguous() on x4; not applied here)
        x, _ = run_level(self.m_body, x4, 3)

        x = self._upsample(3, x + x4)
        x, _ = run_level(self.m_up3, x, 2, emb_in=emb4)
        x = self._upsample(2, x + x3)
        x, _ = run_level(self.m_up2, x, 1, emb_in=emb3)
        x = self._upsample(1, x + x2)
        x, _ = run_level(self.m_up1, x, 0, emb_in=emb1)

        if self.separate_head:
            return self.m_tail(x + x1, img_channels)
        return self.m_tail(x + x1)

    def forward(self, x, sigma=None, gamma=None, physics=None, t=None, y=None):
        r"""
        Run the network on image ``x`` with noise level :math:`\sigma`.

        :param torch.Tensor x: input image, channels in dimension 1.
        :param float, torch.Tensor sigma: noise level. If ``sigma`` is a float, it is used for all images in the batch.
            If ``sigma`` is a tensor, it must be of shape ``(batch_size,)``.
        :param float, torch.Tensor gamma: realignment weight appended as a map with ``cond_type='base'``.
        :param physics: forward operator; required when ``y`` is given or ``emb_physics`` is True.
        :param t: passed through ``timestep_embedding`` (identity) to the blocks.
        :param y: measurements; when given, ``x`` is first realigned with :meth:`realign_input`.
        :raises ValueError: if ``in_channels`` is a list that does not contain ``x.shape[1]``.
        """
        img_channels = x.shape[1]
        if self.emb_physics:
            physics = MultiScaleLinearPhysics(physics, x.shape[-3:], device=x.device)
        if self.separate_head and img_channels not in self.in_channels:
            raise ValueError(f"Input image has {img_channels} channels, but the network only have heads for {self.in_channels} channels.")
        if y is not None:
            x = self.realign_input(x, physics, y)
        return self.forward_unet(x, sigma=sigma, gamma=gamma, physics=physics, t=t, y=y, img_channels=img_channels)
def krylov_embeddings_old(y, p, factor, v=None, N=4, feat_size=1, x_init=None, img_channels=3):
    """Legacy Krylov-type embedding (kept for reference; see ``krylov_embeddings``).

    Let ``z_0 = A^T y`` (or ``x_init[:, :img_channels]``) and ``s = factor ** 2``. The
    first term is ``z_0 - v_0``; then ``z_k = s * A^T A z_{k-1} - v_k`` for ``k = 1..N-1``.
    All ``N`` terms are concatenated along dim 1.

    Args:
        y: measurement (used only when ``x_init`` is None).
        p: linear operator exposing ``A`` and ``A_adjoint``.
        factor: scale; ``A^T A`` is multiplied by ``factor ** 2``.
        v: offsets. With ``feat_size > 1``, ``v`` holds ``N`` blocks of ``C`` channels
            (``v_k`` is block ``k``). Otherwise the same ``v`` is used for every term.
            Defaults to zeros.
        N: number of terms.
        feat_size: selects the per-term (``> 1``) or shared (``<= 1``) ``v`` layout.
        x_init: optional starting point replacing ``A^T y``.
        img_channels: number of leading channels of ``x_init`` to use.

    Returns:
        The concatenated terms (dim 1).
    """
    x = p.A_adjoint(y) if x_init is None else x_init[:, :img_channels, ...]
    norm = factor ** 2

    def ata(u):
        return p.A_adjoint(p.A(u)) * norm

    if feat_size > 1:
        C = x.shape[1]
        if v is None:
            # N blocks are indexed below (0..N-1); the old default only had N-1.
            v = torch.zeros_like(x).repeat(1, N, 1, 1)
        offsets = [v[:, k * C:(k + 1) * C, ...] for k in range(N)]
    else:
        if v is None:
            v = torch.zeros_like(x)
        offsets = [v] * max(N, 1)

    terms = [x - offsets[0]]
    for k in range(1, N):
        x = ata(x) - offsets[k]
        terms.append(x)
    return torch.cat(terms, dim=1)  # single concatenation instead of one per iteration
def krylov_embeddings(y, p, factor, v=None, N=4, x_init=None, img_channels=3):
    """Krylov-subspace embedding of a measurement.

    With ``s = factor ** 2`` and ``x_0 = A^T y`` (or ``x_init``), computes
    ``x_k = s * A^T A x_{k-1} - v`` for ``k = 1..N-1`` and returns
    ``cat([x_0, x_1, ..., x_{N-1}], dim=1)``.

    Args:
        y (torch.Tensor): measurement, used only when ``x_init`` is None.
        p: linear operator exposing ``A`` and ``A_adjoint`` methods.
        factor (float): scale; ``A^T A`` is multiplied by ``factor ** 2``.
        v (torch.Tensor, optional): offset subtracted at every iteration. Defaults to zeros.
        N (int, optional): number of terms, including ``x_0``. Defaults to 4.
        x_init (torch.Tensor, optional): starting point used instead of ``A^T y``.
            It is used whole (callers slice it beforehand) and is never modified.
        img_channels (int, optional): unused; kept for backward compatibility.

    Returns:
        torch.Tensor: the ``N`` terms concatenated along dim 1. The result is always a
        fresh tensor, so it never aliases ``x_init``.
    """
    x = p.A_adjoint(y) if x_init is None else x_init
    norm = factor ** 2
    if v is None:
        v = torch.zeros_like(x)
    terms = [x]
    for _ in range(N - 1):
        terms.append(p.A_adjoint(p.A(terms[-1])) * norm - v)
    # One concatenation (torch.cat always allocates) instead of re-concatenating each step.
    return torch.cat(terms, dim=1)
def grad_embeddings(y, p, factor, v=None, N=4, feat_size=1):
    """Gradient-type embedding built from ``A^T y`` and offsets ``v``.

    With ``s = factor ** 2`` the first term is ``v_0 - A^T y``, and terms ``k = 1..N-1``
    are ``s * A^T A v_k - A^T y``. All ``N`` terms are concatenated along dim 1.

    Args:
        y: measurement.
        p: linear operator exposing ``A`` and ``A_adjoint``.
        factor: scale; ``A^T A`` is multiplied by ``factor ** 2``.
        v: offsets. With ``feat_size > 1``, ``v`` holds ``N`` blocks of ``C`` channels
            (``v_k`` is block ``k``). Otherwise the same ``v`` is used for every term.
            Defaults to zeros.
        N: number of terms.
        feat_size: selects the per-term (``> 1``) or shared (``<= 1``) ``v`` layout.

    Returns:
        The concatenated terms (dim 1).
    """
    Aty = p.A_adjoint(y)
    norm = factor ** 2

    def ata(u):
        return p.A_adjoint(p.A(u)) * norm

    if feat_size > 1:
        C = Aty.shape[1]
        if v is None:
            # N blocks are indexed below (0..N-1); the old default only had N-1.
            v = torch.zeros_like(Aty).repeat(1, N, 1, 1)
        terms = [v[:, :C, ...] - Aty]
        terms += [ata(v[:, k * C:(k + 1) * C, ...]) - Aty for k in range(1, N)]
    else:
        if v is None:
            v = torch.zeros_like(Aty)
        terms = [v - Aty]
        if N > 1:
            # All later terms are identical, so the operator is applied only once.
            terms += [ata(v) - Aty] * (N - 1)
    return torch.cat(terms, dim=1)
def prox_embeddings(y, p, factor, v=None, N=4):
    """Proximal embedding: ``N - 1`` approximate regularised solutions followed by ``A^T y``.

    With ``x = A^T y``, ``s = factor ** 2`` and ``N - 1`` log-spaced weights
    ``gamma_k`` in ``[1e-4, 1e-1]`` (one channel block per weight), runs 3 conjugate-gradient
    iterations on the block-diagonal system

        (s * A^T A + gamma_k I) u_k = x + gamma_k v

    i.e. the same equation as ``s A^T A u + gamma (u - v) = x``. The ``gamma v`` term is
    moved to the right-hand side because CG requires a *linear* operator; the previous
    version passed the affine map to CG, which is only correct for ``v = 0``.

    Args:
        y: measurement.
        p: linear operator exposing ``A`` and ``A_adjoint``.
        factor: scale applied to ``A^T A`` (squared).
        v: prior point, same shape as ``A^T y``; defaults to zeros.
        N: ``N - 1`` regularisation weights are used.

    Returns:
        ``cat([u_1, ..., u_{N-1}, x], dim=1)``.
    """
    x = p.A_adjoint(y)
    C = x.shape[1]
    if v is None:
        v = torch.zeros_like(x)
    v = v.repeat(1, N - 1, 1, 1)
    gamma = torch.logspace(-4, -1, N - 1, device=x.device).repeat_interleave(C)[None, :, None, None]
    norm = factor ** 2

    def normal_op(u):
        # Linear block-diagonal operator: block i is (s A^T A + gamma_i I) on channels [iC, (i+1)C).
        blocks = [p.A_adjoint(p.A(u[:, i * C:(i + 1) * C, ...])) * norm for i in range(N - 1)]
        return torch.cat(blocks, dim=1) + u * gamma

    rhs = x.repeat(1, N - 1, 1, 1) + v * gamma
    u_hat = conjugate_gradient(normal_op, rhs, max_iter=3, tol=1e-3)
    return torch.cat([u_hat, x], dim=1)
# --------------------------------------------
# Res Block: x + conv(relu(conv(x)))
# --------------------------------------------
class MeasCondBlock(nn.Module):
    """Measurement-conditioning branch of a residual block, selected by ``config``.

    - ``'A'``: Krylov embedding of ``y`` -> ``encoding_conv`` -> ReLU.
    - ``'B'``: features decoded to image space (``decoding_conv``) are used as the offset
      ``v`` of the Krylov embedding -> ``encoding_conv`` -> ReLU.
    - ``'C'``: Krylov embedding of ``y`` concatenated with Krylov embeddings started from
      each of the ``c_mult`` decoded image groups -> ``encoding_conv`` -> ReLU.
    - ``'D'``: same features as ``'C'`` but returned *without* the final ReLU. It also creates
      ``gain*`` parameters, which this class only reads in config ``'F'`` code.
    - ``'E'`` / ``'F'``: work in progress. ``__init__`` builds no sub-modules for them, and
      they read attributes it never creates (``relu_decoding``, ``fact_prox``...), so they
      fail unless those attributes are provided externally.

    Each config calls ``physics.set_scale(scale)`` (except ``'F'``) and uses
    ``factor = 2 ** scale`` for the Krylov embeddings.
    """

    def __init__(
        self,
        out_channels=64,
        img_channels=None,
        decode_upscale=None,
        config='A',
        N=4,
        depth_encoding=1,
        relu_in_encoding=False,
        skip_in_encoding=True,
        c_mult=1,
    ):
        super().__init__()
        self.separate_head = isinstance(img_channels, list)
        self.config = config
        assert img_channels is not None, "img_channels should be provided"
        assert decode_upscale is not None, "decode_upscale should be provided"
        if config not in ('A', 'B', 'C', 'D'):
            return
        self.N = N
        self.c_mult = c_mult
        # Registration order (relu, decoder, encoder, gains) matches the original layout,
        # so seeded weight initialisation visits modules in the same order.
        self.relu_encoding = nn.ReLU(inplace=False)
        if config != 'A':
            self.decoding_conv = Tails(out_channels, img_channels, depth=1, scale=1, bias=False, c_mult=c_mult)
        head_kwargs = dict(
            depth=depth_encoding,
            scale=1,
            bias=False,
            relu_in=relu_in_encoding,
            skip_in=skip_in_encoding,
        )
        if config in ('A', 'B'):
            self.encoding_conv = Heads(img_channels, out_channels, c_mult=c_mult, **head_kwargs)
        else:
            # C/D encode 1 + c_mult concatenated Krylov embeddings (see _stacked_krylov_features).
            self.encoding_conv = Heads(img_channels, out_channels, c_mult=c_mult * N, c_add=N, **head_kwargs)
        if config == 'D':
            self.gain = torch.nn.Parameter(torch.tensor([1.0]), requires_grad=True)
            self.gain_gradx = torch.nn.Parameter(torch.tensor([1e-2]), requires_grad=True)
            self.gain_grady = torch.nn.Parameter(torch.tensor([1e-2]), requires_grad=True)
            self.gain_pinvx = torch.nn.Parameter(torch.tensor([1e-2]), requires_grad=True)
            self.gain_pinvy = torch.nn.Parameter(torch.tensor([1e-2]), requires_grad=True)

    def forward(self, x, y, physics, t, emb_in=None, img_channels=None, scale=1):
        """Dispatch to the handler of ``self.config``; ``t`` and ``emb_in`` are unused.

        :raises NotImplementedError: for configs without a handler.
        """
        handler = {
            'A': self.measurement_conditioning_config_A,
            'B': self.measurement_conditioning_config_B,
            'C': self.measurement_conditioning_config_C,
            'D': self.measurement_conditioning_config_D,
            'E': self.measurement_conditioning_config_E,
            'F': self.measurement_conditioning_config_F,
        }.get(self.config)
        if handler is None:
            raise NotImplementedError('Config not implemented')
        return handler(x, y, physics, img_channels=img_channels, scale=scale)

    def measurement_conditioning_config_A(self, x, y, physics, img_channels, scale=0):
        """Encode the Krylov embedding of ``y`` (``x`` is unused)."""
        physics.set_scale(scale)
        meas = krylov_embeddings(y, physics, 2 ** scale, N=self.N, img_channels=img_channels)
        return self.relu_encoding(self.encoding_conv(meas))

    def measurement_conditioning_config_B(self, x, y, physics, img_channels, scale=0):
        """Encode the Krylov embedding of ``y`` offset by the decoded features."""
        physics.set_scale(scale)
        dec = self.decoding_conv(x, img_channels)
        meas = krylov_embeddings(y, physics, 2 ** scale, v=dec, N=self.N, img_channels=img_channels)
        return self.relu_encoding(self.encoding_conv(meas))

    def _stacked_krylov_features(self, x, y, physics, img_channels, scale):
        """Shared C/D path: encode [Krylov(y), Krylov(from dec group 0), ..., group c_mult-1]."""
        physics.set_scale(scale)
        dec = self.decoding_conv(x, img_channels)
        factor = 2 ** scale
        embeddings = [krylov_embeddings(y, physics, factor, N=self.N, img_channels=img_channels)]
        # The first decoded group is always used, even if c_mult < 1.
        embeddings.append(
            krylov_embeddings(y, physics, factor, N=self.N, x_init=dec[:, :img_channels, ...], img_channels=img_channels)
        )
        for c in range(1, self.c_mult):
            embeddings.append(
                krylov_embeddings(
                    y, physics, factor, N=self.N,
                    x_init=dec[:, img_channels * c:img_channels * (c + 1)],
                    img_channels=img_channels,
                )
            )
        return self.encoding_conv(torch.cat(embeddings, dim=1))

    def measurement_conditioning_config_C(self, x, y, physics, img_channels, scale=0):
        """Stacked Krylov features followed by ReLU."""
        return self.relu_encoding(self._stacked_krylov_features(x, y, physics, img_channels, scale))

    def measurement_conditioning_config_D(self, x, y, physics, img_channels, scale=0):
        """Stacked Krylov features *without* the final ReLU (the ReLU output was unused before)."""
        return self._stacked_krylov_features(x, y, physics, img_channels, scale)

    def measurement_conditioning_config_F(self, x, y, physics, img_channels, scale=0):
        """Gradient / pseudo-inverse conditioning (work in progress).

        ``scale`` is accepted so that :meth:`forward` can dispatch here (it always passes
        ``scale=``), but ``physics.set_scale`` is not called. NOTE(review): reads
        ``relu_decoding``, ``decoding_conv`` and ``gain_*``, which ``__init__`` does not create
        for config 'F'.
        """
        dec_large = self.decoding_conv(x, img_channels)
        dec = self.relu_decoding(dec_large)
        Adec = physics.A(dec)
        grad = physics.A_adjoint(self.gain_gradx ** 2 * Adec - self.gain_grady ** 2 * y)  # TODO: check if we need to have L2 (depending on noise nature, can be automated)
        if 'tomography' in physics.__class__.__name__.lower():
            pinv = physics.prox_l2(dec, self.gain_pinvx ** 2 * Adec - self.gain_pinvy ** 2 * y, gamma=1e9)
        else:
            pinv = physics.A_dagger(self.gain_pinvx ** 2 * Adec - self.gain_pinvy ** 2 * y)  # TODO: do we set this to gain_gradx ?
        emb = grad - pinv  # zero for denoising and inpainting
        im_emb = dec - physics.A_adjoint_A(dec)  # zero for denoising, not for inpainting # TODO: add gains here too
        grad_large = emb + im_emb
        emb_grad = self.encoding_conv(grad_large)
        return self.relu_encoding(emb_grad)

    def measurement_conditioning_config_E(self, x, y, physics, img_channels, scale=1):
        """Prox-based conditioning with an error-estimated SNR (work in progress).

        NOTE(review): reads ``fact_prox``, ``fact_prox_skip_1`` and ``fact_prox_skip_2``,
        which ``__init__`` never creates, and ``decoding_conv``, which is not built for config 'E'.
        """
        dec = self.decoding_conv(x, img_channels)
        physics.set_scale(scale)
        # TODO: check things are batched
        f = physics.factor if hasattr(physics, "factor") else 1.0
        err = physics.A_adjoint(physics.A(dec) - y)
        snr = dec.reshape(dec.shape[0], -1).abs().mean(dim=1) / (err.reshape(err.shape[0], -1).abs().mean(dim=1) + 1e-4)
        gamma = 1 / (1e-4 + 1 / (snr * f ** 2 + 1))  # TODO: check square-root / mean / check if we need to add a factor in front
        gamma_est = gamma[(...,) + (None,) * (dec.dim() - 1)]
        prox = physics.prox_l2(dec, y, gamma=gamma_est * self.fact_prox)
        emb = self.fact_prox_skip_1 * prox + self.fact_prox_skip_2 * dec
        emb_grad = self.encoding_conv(emb)
        return self.relu_encoding(emb_grad)
class ResBlock(nn.Module):
    """Residual block computing ``x + conv2(relu(conv1(x)))``.

    When ``emb_physics`` is True, ``gain * PhysicsBlock(...)`` (a :class:`MeasCondBlock`
    evaluated on the intermediate ``relu(conv1(x))`` features) is also added.

    NOTE(review): with ``head`` or ``tail`` set, ``conv1``/``conv2`` are not built, so
    ``forward`` cannot run; ``self.head``, ``emb_linear`` and ``emb_sigma`` are never used
    by ``forward``.
    """

    def __init__(
        self,
        in_channels=64,
        out_channels=64,
        kernel_size=3,
        stride=1,
        padding=1,
        bias=True,
        mode="CRC",
        negative_slope=0.2,
        embedding=False,
        emb_channels=None,
        emb_physics=False,
        img_channels=None,
        decode_upscale=None,
        config='A',
        head=False,
        tail=False,
        N=4,
        c_mult=1,
        depth_encoding=1,
        relu_in_encoding=False,
        skip_in_encoding=True,
    ):
        super().__init__()
        is_body = not (head or tail)
        if is_body:
            assert in_channels == out_channels, "Only support in_channels==out_channels."
        self.separate_head = isinstance(img_channels, list)
        self.config = config
        self.is_head = head
        self.is_tail = tail

        if head:
            self.head = InHead(img_channels, out_channels, input_layer=True)

        if is_body:
            # Shared positional tail of the conv(...) call: k, stride, pad, bias, "C", slope.
            conv_rest = (kernel_size, stride, padding, bias, "C", negative_slope)
            self.conv1 = conv(in_channels, out_channels, *conv_rest)
            self.nl = nn.ReLU(inplace=True)
            self.conv2 = conv(out_channels, out_channels, *conv_rest)

        if embedding:
            self.gain = torch.nn.Parameter(torch.tensor([1.0]), requires_grad=True)
            self.emb_linear = MPConv(emb_channels, out_channels, kernel=[])

        self.emb_physics = emb_physics
        if emb_physics:
            self.gain = torch.nn.Parameter(torch.tensor([1.0]), requires_grad=True)
            self.PhysicsBlock = MeasCondBlock(
                out_channels=out_channels,
                config=config,
                c_mult=c_mult,
                img_channels=img_channels,
                decode_upscale=decode_upscale,
                N=N,
                depth_encoding=depth_encoding,
                relu_in_encoding=relu_in_encoding,
                skip_in_encoding=skip_in_encoding,
            )

    def forward(self, x, emb_sigma=None, physics=None, t=None, y=None, emb_in=None, img_channels=None, scale=0):
        """Return ``x + conv2(relu(conv1(x)))`` plus the gated physics term when enabled."""
        hidden = self.nl(self.conv1(x))
        residual = self.conv2(hidden)
        physics_term = 0
        if self.emb_physics:
            physics_term = self.gain * self.PhysicsBlock(
                hidden, y, physics, t, img_channels=img_channels, scale=scale
            )
        return x + residual + physics_term
def calculate_fan_in_and_fan_out(tensor, pytorch_style: bool = True):
    """Return ``(fan_in, fan_out)`` for a 2-D, 4-D or 5-D weight tensor.

    Adapted from
    https://github.com/megvii-research/basecls/blob/main/basecls/layers/wrapper.py#L77

    A 5-D ``GOIKK`` weight is reduced to ``OIKK`` first: with ``pytorch_style`` the group and
    output axes are merged, otherwise only the first group is kept. The receptive field is
    the product of all dimensions after the first two (1 for a 2-D tensor).

    Raises:
        ValueError: if ``tensor`` is not 2-, 4- or 5-dimensional.
    """
    ndim = len(tensor.shape)
    if ndim not in (2, 4, 5):
        raise ValueError("fan_in and fan_out can only be computed for tensor with 2/4/5 dimensions")
    if ndim == 5:
        tensor = tensor.reshape(-1, *tensor.shape[2:]) if pytorch_style else tensor[0]
    out_maps, in_maps = tensor.shape[0], tensor.shape[1]
    receptive_field = math.prod(tensor.shape[2:])
    return in_maps * receptive_field, out_maps * receptive_field
def weights_init_unext(m, gain_conv=1.0, gain_linear=1.0, init_type="ortho"):
    """
    Re-initialise the weights of every submodule of ``m`` in place.

    Rules, applied to each module yielded by ``m.modules()``:
      * Conv2d / ConvTranspose2d whose ``str()`` contains 'skip': weight set to ones.
      * other Conv2d / ConvTranspose2d with last kernel dim < 4: orthogonal, gain 0.2.
      * other Conv2d / ConvTranspose2d: normal(0, std=sqrt(2 / fan_out)).
      * nn.Linear whose ``str()`` does not contain 'skip': normal(0, std=0.01).
    Biases and all other module types are left untouched. Objects without a
    ``modules`` attribute are ignored.

    ``gain_conv``, ``gain_linear`` and ``init_type`` are currently unused and are
    kept only for interface compatibility.

    NOTE(review): 'skip' is matched against ``str(submodule)`` (the module repr),
    not the attribute name. A plain Conv2d repr such as
    ``Conv2d(3, 64, kernel_size=(3, 3), ...)`` does not contain 'skip', so the
    ones_ branch presumably only fires for custom modules whose repr mentions it.
    Confirm intent; ``m.named_modules()`` would match by name instead.
    """
    if not hasattr(m, "modules"):
        return
    conv_types = (nn.Conv2d, nn.ConvTranspose2d)
    for module in m.modules():
        is_skip = 'skip' in str(module)
        if isinstance(module, conv_types):
            weight = module.weight
            if is_skip:
                nn.init.ones_(weight.data)
            elif weight.data.shape[-1] < 4:
                # Small kernels: scaled-down orthogonal init.
                nn.init.orthogonal_(weight.data, gain=0.2)
            else:
                _, fan_out = calculate_fan_in_and_fan_out(weight)
                nn.init.normal_(weight, 0, math.sqrt(2 / fan_out))
        elif isinstance(module, nn.Linear) and not is_skip:
            nn.init.normal_(module.weight.data, std=0.01)
def old2new(old_key):
    """
    Convert an old DRUNet state-dict key to the corresponding UNExt key.

    The trailing parameter name (``weight`` or ``bias``) is preserved. Within a
    residual block, ``res.0`` maps to ``conv1`` and any other index (e.g. ``res.2``)
    maps to ``conv2``.

    PATTERNS TO MATCH:
    1. Downsampling stages:
        - residual blocks (position maps 1:1 to the enc index):
          m_down3.2.res.0.weight -> m_down3.enc.2.conv1.weight
        - downsampling conv at position 4:
          m_down3.4.weight -> pool3.weight
    2. Upsampling stages:
        - upsampling conv at position 0:
          m_up3.0.weight -> up3.weight
        - residual blocks (shifted down by one, since position 0 is the upsampler):
          m_up3.2.res.0.weight -> m_up3.enc.1.conv1.weight
    3. Body blocks:
        m_body.0.res.2.weight -> m_body.enc.0.conv2.weight

    Args:
        old_key (str): The old key from the state dictionary.

    Returns:
        str or None: The new key if matched, otherwise None.
    """
    # Residual block inside a downsampling stage.
    match = re.search(r"(m_down\d+)\.(\d+)\.res\.(\d+)\.(weight|bias)", old_key)
    if match:
        prefix, index, conv_index, param = match.groups()  # e.g. "m_down2", "3", "0", "weight"
        new_conv_index = 1 if int(conv_index) == 0 else 2
        return f"{prefix}.enc.{index}.conv{new_conv_index}.{param}"

    # Residual block inside an upsampling stage (index shifted by one).
    match = re.search(r"(m_up\d+)\.(\d+)\.res\.(\d+)\.(weight|bias)", old_key)
    if match:
        prefix, index, conv_index, param = match.groups()  # e.g. "m_up2", "3", "0", "weight"
        new_conv_index = 1 if int(conv_index) == 0 else 2
        return f"{prefix}.enc.{int(index) - 1}.conv{new_conv_index}.{param}"

    # Downsampling conv at position 4 of a downsampling stage.
    match = re.search(r"m_down(\d+)\.4\.(weight|bias)", old_key)
    if match:
        index, param = match.groups()  # e.g. "1", "weight"
        return f"pool{index}.{param}"

    # Upsampling conv at position 0 of an upsampling stage.
    match = re.search(r"m_up(\d+)\.0\.(weight|bias)", old_key)
    if match:
        index, param = match.groups()  # e.g. "1", "weight"
        return f"up{index}.{param}"

    # Residual blocks in the body.
    match = re.search(r"(m_body)\.(\d+)\.res\.(\d+)\.(weight|bias)", old_key)
    if match:
        prefix, index, conv_index, param = match.groups()  # e.g. "m_body", "0", "2", "weight"
        new_conv_index = 1 if int(conv_index) == 0 else 2
        return f"{prefix}.enc.{index}.conv{new_conv_index}.{param}"

    # If no patterns match, return None
    return None
def update_keyvals_headtail(old_key, old_value, init_value=None, new_key_name='m_head.conv0.weight', conditioning='base'):
    """
    Adapt an old DRUNet head/tail weight to the shape of the matching new UNExt weight.

    The branch is chosen by substring tests on ``old_key`` ('head' is checked before
    'tail'). The adapted tensor is returned under ``new_key_name``; ``old_key`` is only
    used for that dispatch and the fallback message.

    Args:
        old_key (str): The old key from the state dictionary.
        old_value (torch.Tensor): The old weight tensor.
        init_value (torch.Tensor): The freshly initialised weight of the new model.
            Except in the tail "same size" case, the result is built with
            ``torch.zeros_like(init_value)``, so it takes its shape/dtype/device.
            Although it defaults to None, it is required for head/tail keys
            (``init_value.shape`` is read).
        new_key_name (str): Key under which the adapted weight is returned.
        conditioning (str): For head keys, 'base' selects the two-trailing-channel
            layout; any other value selects the plain channel truncation.

    Returns:
        dict or None: ``{new_key_name: new_value}`` for head/tail keys; otherwise a
        message is printed and None is returned.
    """
    if 'head' in old_key:
        if conditioning == 'base':
            # Input-channel counts (dim 1) of the new and old head weights.
            c_in = init_value.shape[1]
            c_in_old = old_value.shape[1]  # NOTE(review): currently unused
            # New weight starts at zero; then:
            #   - both of the last two new input channels copy the old LAST input channel
            #     (presumably the noise-level channel of DRUNet — confirm),
            #   - the first c_in-2 new channels copy the first c_in-2 old channels.
            # Assumes the old weight has at least c_in-2 input channels — TODO confirm.
            new_value = torch.zeros_like(init_value.detach())
            new_value[:, -2:-1, ...] = old_value[:, -1:, ...]
            new_value[:, -1:, ...] = old_value[:, -1:, ...]
            new_value[:, :c_in-2, ...] = old_value[:, :c_in-2, ...]
            return {new_key_name: new_value}
        else:
            c_in = init_value.shape[1]
            c_in_old = old_value.shape[1]  # NOTE(review): currently unused
            new_value = torch.zeros_like(init_value.detach())
            # NOTE(review): `-1:-2` is an empty slice, so this assignment is a no-op.
            new_value[:, -1:-2, ...] = old_value[:, -1:, ...]
            # NOTE(review): overwritten by the full assignment on the next line.
            new_value[:, -1:, ...] = old_value[:, -1:, ...]
            # Net effect of this branch: the first c_in input channels of the old weight,
            # cast into init_value's dtype/device. Needs the old weight to provide c_in
            # input channels (or exactly 1, which broadcasts), else this raises.
            new_value[:, ...] = old_value[:, :c_in, ...]
            return {new_key_name: new_value}
    elif 'tail' in old_key:
        # For the tail, dim 0 (output channels) is compared, despite the `c_in` name.
        c_in = init_value.shape[0]
        c_in_old = old_value.shape[0]
        new_value = torch.zeros_like(init_value.detach())
        if c_in == c_in_old:
            # Same size: reuse the old tensor as-is (detached, not copied; keeps old dtype/device).
            new_value = old_value.detach()
        elif c_in < c_in_old:
            # Fewer outputs in the new model: keep the first c_in old output channels.
            new_value = torch.zeros_like(init_value.detach())
            new_value[:, ...] = old_value[:c_in, ...]
        # NOTE(review): if c_in > c_in_old, the all-zero tensor is returned and the
        # old weight is silently discarded.
        return {new_key_name: new_value}
    else:
        # Fallback: not a head/tail key; prints a message and implicitly returns None.
        print(f"Key {old_key} does not contain 'head' or 'tail'.")
# test the network
if __name__ == "__main__":
    # Smoke test: push one random 1x3x128x128 batch through a default-constructed UNeXt.
    model = UNeXt()
    sample = torch.randn(1, 3, 128, 128)
    output = model(sample, 0.1)
    # print(output.shape)
    # print(output)
# Case for diagonal physics
# IDEA 1: kills signal in the image of A
# im_emb = dec - physics.A_adjoint_A(dec) # will be 0 in the case of denoising, but not inpainting # TODO: add gains here too
# IDEA 2: compute norm of signal in ker of A
# normker = (dec - physics.A_adjoint_A(dec)).norm() / (dec.norm() + 1e-4)
# im_emb = normker * physics.A_adjoint(self.gain_diag_x * physics.A(dec) - self.gain_diag_y * y) # will be 0 in the case of denoising, but not inpainting # TODO: add gains here too
# IDEA 3: same as above but add the pinv as well
# normker = (dec - physics.A_adjoint_A(dec)).norm() / (dec.norm() + 1e-4)
# grad_term = physics.A_adjoint(self.gain_diag_x * physics.A(dec) - self.gain_diag_y * y)
# # pinv_term = physics.A_dagger(self.gain_diagpinv_x * physics.A(dec) - self.gain_diagpinv_y * y)
# if 'tomography' in physics.__class__.__name__.lower(): # or 'pansharp' in physics.__class__.__name__.lower():
# pinv_term = physics.prox_l2(dec, self.gain_diagpinv_x ** 2 * Adec - self.gain_diagpinv_y ** 2 * y, gamma=1e9)
# else:
# pinv_term = physics.A_dagger(self.gain_diagpinv_x ** 2 * Adec - self.gain_diagpinv_y ** 2 * y) # TODO: do we set this to gain_gradx ? To get 0 during training too?? Better for denoising I guess
# im_emb = normker * (grad_term + pinv_term) # will be 0 in the case of denoising, but not inpainting # TODO: add gains here too
# # Mix it
# if hasattr(physics.noise_model, 'sigma'):
# sigma = physics.noise_model.sigma # SNR ? x /= sigma ** 2
# snr = (y.abs().mean()) / (sigma + 1e-4) # SNR equivariant # TODO: add epsilon
# snr = snr[(...,) + (None,) * (im_emb.dim() - 1)]
# else:
# snr = 1e4
#
# grad_large = emb + self.gain_diag * (1 + self.gain_noise / snr) * im_emb