Spaces:
Sleeping
Sleeping
# Code borrowed from Kai Zhang https://github.com/cszn/DPIR/tree/master/models | |
import re | |
import math | |
import functools | |
import deepinv as dinv | |
from deepinv.utils import plot, TensorList | |
import torch | |
from torch.func import vmap | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from torchvision import transforms | |
from deepinv.optim.utils import conjugate_gradient | |
from physics.multiscale import MultiScaleLinearPhysics, Pad | |
from models.blocks import EquivMaxPool, AffineConv2d, ConvNextBlock2, NoiseEmbedding, MPConv, TimestepEmbedding, conv, downsample_strideconv, upsample_convtranspose | |
from models.heads import Heads, Tails, InHead, OutTail, ConvChannels, SNRModule, EquivConvModule, EquivHeads | |
# Module-level device helpers: `cuda` is True when PyTorch sees a CUDA device and
# `Tensor` is the matching legacy float tensor type.
cuda = torch.cuda.is_available()
Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor
### --------------- MODEL --------------- | |
class BaseEncBlock(nn.Module):
    """
    Stack of ``nb`` identically configured :class:`ResBlock` layers.

    Every constructor argument except ``nb`` is passed unchanged to each ``ResBlock``.
    In :meth:`forward` the blocks are applied in sequence and each one receives the
    same conditioning arguments.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        bias=False,
        mode="CRC",
        nb=2,
        embedding=False,
        emb_channels=None,
        emb_physics=False,
        img_channels=None,
        decode_upscale=None,
        config='A',
        N=4,
        c_mult=1,
        depth_encoding=1,
        relu_in_encoding=False,
        skip_in_encoding=True,
    ):
        super().__init__()
        self.config = config
        # Shared keyword configuration for every ResBlock in the stack.
        block_kwargs = dict(
            bias=bias,
            mode=mode,
            embedding=embedding,
            emb_channels=emb_channels,
            emb_physics=emb_physics,
            img_channels=img_channels,
            decode_upscale=decode_upscale,
            config=config,
            N=N,
            c_mult=c_mult,
            depth_encoding=depth_encoding,
            relu_in_encoding=relu_in_encoding,
            skip_in_encoding=skip_in_encoding,
        )
        self.enc = nn.ModuleList(
            [ResBlock(in_channels, out_channels, **block_kwargs) for _ in range(nb)]
        )

    def forward(self, x, emb_sigma=None, physics=None, t=None, y=None, emb_in=None, img_channels=None, scale=0):
        """
        Apply the blocks in order.

        ``emb_in`` is accepted for interface compatibility but is not passed to the blocks.
        """
        for block in self.enc:
            x = block(
                x,
                emb_sigma=emb_sigma,
                physics=physics,
                t=t,
                y=y,
                img_channels=img_channels,
                scale=scale,
            )
        return x
class NextEncBlock(nn.Module):
    """
    Stack of ``nb`` :class:`ConvNextBlock2` layers applied one after another.

    Each layer is built with the same ``in_channels``/``out_channels``/``bias``/``mode``/
    ``mult_fact`` and, in :meth:`forward`, receives ``emb_sigma`` as its second argument.
    """

    def __init__(
        self, in_channels, out_channels, bias=False, mode="", mult_fact=4, nb=2
    ):
        super().__init__()
        layers = []
        for _ in range(nb):
            layers.append(
                ConvNextBlock2(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    bias=bias,
                    mode=mode,
                    mult_fact=mult_fact,
                )
            )
        self.enc = nn.ModuleList(layers)

    def forward(self, x, emb_sigma=None):
        out = x
        for layer in self.enc:
            out = layer(out, emb_sigma)
        return out
class UNeXt(nn.Module):
    r"""
    DRUNet-style U-Net with optional noise-level and measurement (physics) conditioning.

    The network architecture is based on the paper
    `Learning deep CNN denoiser prior for image restoration <https://arxiv.org/abs/1704.03264>`_.
    It has a U-Net like structure with four scales: an input head, three encoder stages
    separated by downsampling, a bottleneck, and three decoder stages separated by upsampling.
    The decoder uses additive skip connections, and an output tail closes the network.

    With ``cond_type="base"`` the noise level ``sigma`` and the parameter ``gamma`` are
    concatenated to the input as two constant channels (see :meth:`base_conditioning`).
    When ``y`` is passed to :meth:`forward`, the input is first replaced by
    ``physics.prox_l2`` applied to it (see :meth:`realign_input`).

    :param int, list in_channels: number of input channels, or a list of supported channel
        counts (one head/tail per entry is then used). Defaults to ``[1, 2, 3]``.
    :param int, list out_channels: accepted for compatibility; the output tail is built from
        ``in_channels``. Defaults to ``[1, 2, 3]``.
    :param list nc: number of feature channels at each of the 4 scales.
        Defaults to ``[64, 128, 256, 512]``.
    :param int nb: number of blocks per scale.
    :param str conv_type: ``"next"`` (:class:`NextEncBlock`) or ``"base"`` (:class:`BaseEncBlock`).
    :param str pool_type: ``"next_max"`` (``EquivMaxPool``) or ``"base"`` (strided convolution
        down / transposed convolution up). Any other value, including the default ``"next"``,
        raises ``NotImplementedError``.
    :param str cond_type: ``"base"`` (constant-map conditioning) or ``"edm"`` (noise embedding).
    :param device: device on which the model is created / moved.
    :param bool bias: bias flag of the ``"next"`` blocks.
    :param str mode: mode string of the ``"next"`` blocks.
    :param bool residual: stored as ``self.residual``; not used by this class's forward pass.
    :param act_mode, layer_scale_init_value, drop_prob, replk: accepted for compatibility; unused here.
    :param str init_type: init scheme passed to ``weights_init_unext`` when ``conv_type="next"``.
    :param float gain_init_conv: conv init gain used when ``conv_type="next"``.
    :param float gain_init_linear: linear init gain used when ``conv_type="next"``.
    :param int mult_fact: expansion factor of the ``"next"`` blocks.
    :param str antialias: antialiasing option of ``EquivMaxPool``.
    :param bool emb_physics: forwarded to the ``"base"`` blocks. When True, :meth:`forward`
        also wraps ``physics`` in ``MultiScaleLinearPhysics``.
    :param str config: conditioning variant forwarded to the blocks. ``'D'`` freezes every
        parameter whose name does not contain "PhysicsBlock", "gain", "fact_realign",
        "m_head" or "m_tail".
    :param str, None pretrained_pth: checkpoint loaded with :meth:`load_drunet_weights`;
        ``'jz'`` selects a hard-coded cluster path.
    :param int N: forwarded to the ``"base"`` blocks (Krylov depth of measurement conditioning).
    :param int c_mult: forwarded to the ``"base"`` blocks.
    :param int depth_encoding: forwarded to the ``"base"`` blocks.
    :param bool relu_in_encoding: forwarded to the ``"base"`` blocks.
    :param bool skip_in_encoding: forwarded to the ``"base"`` blocks.
    """

    def __init__(
        self,
        in_channels=None,
        out_channels=None,
        nc=None,
        nb=4,  # 4 in DRUNet but out of memory
        conv_type="next",  # should be 'base' or 'next'
        pool_type="next",  # only 'base' and 'next_max' are built below
        cond_type="base",  # conditioning, should be 'base' or 'edm'
        device=None,
        bias=False,
        mode="",
        residual=False,
        act_mode="R",
        layer_scale_init_value=1e-6,
        init_type="ortho",
        gain_init_conv=1.0,
        gain_init_linear=1.0,
        drop_prob=0.0,
        replk=False,
        mult_fact=4,
        antialias="gaussian",
        emb_physics=False,
        config='A',
        pretrained_pth=None,
        N=4,
        c_mult=1,
        depth_encoding=1,
        relu_in_encoding=False,
        skip_in_encoding=True,
    ):
        super(UNeXt, self).__init__()
        # Fresh lists per instance: list literals as defaults would be shared by every instance.
        if in_channels is None:
            in_channels = [1, 2, 3]
        if out_channels is None:
            out_channels = [1, 2, 3]
        if nc is None:
            nc = [64, 128, 256, 512]

        self.residual = residual
        self.conv_type = conv_type
        self.pool_type = pool_type
        self.emb_physics = emb_physics
        self.config = config
        self.in_channels = in_channels
        # Learnable multiplier on the step size used in realign_input.
        self.fact_realign = torch.nn.Parameter(torch.tensor([1.0], device=device))
        self.separate_head = isinstance(in_channels, list)

        assert cond_type in ["base", "edm"], "cond_type should be 'base' or 'edm'"
        self.cond_type = cond_type

        if cond_type == "base" and config != 'E':
            if isinstance(in_channels, list):
                # base_conditioning appends a sigma map and a gamma map.
                in_channels_first = [c + 2 for c in in_channels]
            else:  # old head
                # NOTE(review): base_conditioning appends two maps but only +1 is added
                # here - confirm InHead accounts for this.
                in_channels_first = in_channels + 1
        else:
            in_channels_first = in_channels
        if cond_type != "base":
            self.noise_embedding = NoiseEmbedding(
                num_channels=in_channels, emb_channels=max(nc), device=device
            )
        self.timestep_embedding = lambda x: x

        self.m_head = InHead(in_channels_first, nc[0])

        if conv_type == "next":
            def make_block(channels, _decode_upscale):
                return NextEncBlock(
                    channels, channels, bias=bias, mode=mode, mult_fact=mult_fact, nb=nb
                )
        elif conv_type == "base":
            embedding = cond_type != "base"
            emb_channels = max(nc)

            def make_block(channels, decode_upscale):
                return BaseEncBlock(
                    channels,
                    channels,
                    bias=False,
                    mode="CRC",
                    nb=nb,
                    embedding=embedding,
                    emb_channels=emb_channels,
                    emb_physics=emb_physics,
                    img_channels=in_channels,
                    decode_upscale=decode_upscale,
                    config=config,
                    N=N,
                    c_mult=c_mult,
                    depth_encoding=depth_encoding,
                    relu_in_encoding=relu_in_encoding,
                    skip_in_encoding=skip_in_encoding,
                )
        else:
            raise NotImplementedError("conv_type should be 'base' or 'next'")

        # Built in this exact order so that module registration (and therefore the
        # order of random initialisation and of self.apply below) is unchanged.
        self.m_down1 = make_block(nc[0], 1)
        self.m_down2 = make_block(nc[1], 2)
        self.m_down3 = make_block(nc[2], 4)
        self.m_body = make_block(nc[3], 8)
        self.m_up3 = make_block(nc[2], 4)
        self.m_up2 = make_block(nc[1], 2)
        self.m_up1 = make_block(nc[0], 1)

        if pool_type == "next_max":
            self.pool1 = EquivMaxPool(
                antialias=antialias, in_channels=nc[0], out_channels=nc[1], device=device
            )
            self.pool2 = EquivMaxPool(
                antialias=antialias, in_channels=nc[1], out_channels=nc[2], device=device
            )
            self.pool3 = EquivMaxPool(
                antialias=antialias, in_channels=nc[2], out_channels=nc[3], device=device
            )
        elif pool_type == "base":
            self.pool1 = downsample_strideconv(nc[0], nc[1], bias=False, mode="2")
            self.pool2 = downsample_strideconv(nc[1], nc[2], bias=False, mode="2")
            self.pool3 = downsample_strideconv(nc[2], nc[3], bias=False, mode="2")
            self.up3 = upsample_convtranspose(nc[3], nc[2], bias=False, mode="2")
            self.up2 = upsample_convtranspose(nc[2], nc[1], bias=False, mode="2")
            self.up1 = upsample_convtranspose(nc[1], nc[0], bias=False, mode="2")
        else:
            raise NotImplementedError("pool_type should be 'base' or 'next_max'")

        self.m_tail = OutTail(nc[0], in_channels)

        if conv_type == "base":
            # The "base" blocks always use orthogonal init with conv gain 0.2.
            init_func = functools.partial(
                weights_init_unext, init_type="ortho", gain_conv=0.2
            )
        else:
            init_func = functools.partial(
                weights_init_unext,
                init_type=init_type,
                gain_conv=gain_init_conv,
                gain_linear=gain_init_linear,
            )
        self.apply(init_func)

        if pretrained_pth == 'jz':
            pth = '/lustre/fswork/projects/rech/nyd/commun/mterris/base_checkpoints/drunet_deepinv_color_finetune_22k.pth'
            self.load_drunet_weights(pth)
        elif pretrained_pth is not None:
            self.load_drunet_weights(pretrained_pth)

        if self.config == 'D':
            # Fine-tuning mode: only physics blocks, gains, realignment factor, head and tail train.
            trainable_markers = ("PhysicsBlock", "gain", "fact_realign", "m_head", "m_tail")
            for name, param in self.named_parameters():
                if not any(marker in name for marker in trainable_markers):
                    param.requires_grad = False

        if device is not None:
            self.to(device)

    def load_drunet_weights(self, ckpt_pth):
        """
        Initialise the network from a DRUNet checkpoint and print a key-matching report.

        Keys containing "head" or "tail" are converted with ``update_keyvals_headtail``.
        When ``in_channels`` is a list, this is done once per head/tail conv, starting from the
        current weights. Every other key is renamed with ``old2new``; keys it returns ``None``
        for are reported as unmatched. The result is loaded with ``strict=False``.

        :param str ckpt_pth: path to the checkpoint file.
        """
        # NOTE: torch.load unpickles the file - only load trusted checkpoints.
        state_dict = torch.load(ckpt_pth, map_location=lambda storage, loc: storage)
        new_state_dict = {}
        matched_keys = []  # (old_key, new_key) pairs successfully renamed
        unmatched_keys = []  # keys old2new could not map
        excluded_keys = []  # head/tail keys handled separately
        exclude_patterns = ["head", "tail"]

        for old_key, value in state_dict.items():
            if any(excluded in old_key for excluded in exclude_patterns):
                excluded_keys.append(old_key)
                continue
            new_key = old2new(old_key)
            if new_key is not None:
                matched_keys.append((old_key, new_key))
                new_state_dict[new_key] = value
            else:
                unmatched_keys.append(old_key)

        # Snapshot once instead of rebuilding the full state dict for every key; the model
        # itself is not modified until load_state_dict below.
        own_state = self.state_dict()
        for excluded_key in excluded_keys:
            if isinstance(self.in_channels, list):
                for i in range(len(self.in_channels)):
                    if 'tail' in excluded_key:
                        new_key = f"m_tail.conv{i}.weight"
                    else:
                        new_key = f"m_head.conv{i}.weight"
                    conditioning = 'base'
                    new_kv = update_keyvals_headtail(
                        excluded_key,
                        state_dict[excluded_key],
                        init_value=own_state[new_key],
                        new_key_name=new_key,
                        conditioning=conditioning,
                    )
                    new_state_dict.update(new_kv)
            else:
                new_kv = update_keyvals_headtail(excluded_key, state_dict[excluded_key])
                new_state_dict.update(new_kv)

        print("Matched keys:")
        for old_key, new_key in matched_keys:
            print(f"{old_key} -> {new_key}")
        self.load_state_dict(new_state_dict, strict=False)
        print("\nUnmatched keys:")
        for unmatched_key in unmatched_keys:
            print(unmatched_key)
        print("Weights loaded from ", ckpt_pth)

    def constant2map(self, value, x):
        """
        Broadcast ``value`` to a ``(B, 1, H, W)`` map matching ``x``'s batch and spatial size.

        :param value: Python number, 0-d tensor, or tensor with one entry per batch element.
        :param torch.Tensor x: reference tensor indexed as ``(B, C, H, W)``.
        """
        if isinstance(value, torch.Tensor):
            if value.ndim > 0:
                value_map = value.view(x.size(0), 1, 1, 1)
                value_map = value_map.expand(-1, 1, x.size(2), x.size(3))
            else:
                value_map = torch.ones(
                    (x.size(0), 1, x.size(2), x.size(3)), device=x.device
                ) * value[None, None, None, None].to(x.device)
        else:
            value_map = (
                torch.ones((x.size(0), 1, x.size(2), x.size(3)), device=x.device)
                * value
            )
        return value_map

    def base_conditioning(self, x, sigma, gamma):
        """Concatenate constant ``sigma`` and ``gamma`` maps to ``x`` along the channel axis."""
        noise_level_map = self.constant2map(sigma, x)
        gamma_map = self.constant2map(gamma, x)
        return torch.cat((x, noise_level_map, gamma_map), 1)

    @staticmethod
    def _physics_levels(physics):
        """Return ``[physics, physics.base, physics.base.base]``, stopping at the first missing ``base``."""
        levels = [physics]
        while len(levels) < 3 and hasattr(levels[-1], "base"):
            levels.append(levels[-1].base)
        return levels

    @classmethod
    def _physics_factor(cls, physics):
        """Outermost ``factor`` attribute found along the ``base`` chain, else ``1.0``."""
        for level in cls._physics_levels(physics):
            if hasattr(level, "factor"):
                return level.factor
        return 1.0

    @classmethod
    def _physics_sigma(cls, physics):
        """Innermost ``noise_model.sigma`` found along the ``base`` chain, else ``1e-6``."""
        sigma = 1e-6  # default value
        for level in cls._physics_levels(physics):
            # getattr with a default: a physics object without noise_model no longer raises.
            noise_model = getattr(level, "noise_model", None)
            if hasattr(noise_model, "sigma"):
                sigma = noise_model.sigma
        return sigma

    def realign_input(self, x, physics, y):
        """
        Replace ``x`` by ``physics.prox_l2(x, y, gamma)`` with a data-dependent ``gamma``.

        ``gamma`` is derived from an SNR proxy: mean ``|y|`` (first entry when ``y`` is a
        ``TensorList``) divided by the noise level. It is scaled by ``factor ** 2`` and by
        the learnable ``self.fact_realign``.
        """
        f = self._physics_factor(physics)
        sigma = self._physics_sigma(physics)
        y_ref = y[0] if isinstance(y, TensorList) else y
        num = y_ref.reshape(y_ref.shape[0], -1).abs().mean(1)
        snr = num / (sigma + 1e-4)  # SNR equivariant
        gamma = 1 / (1e-4 + 1 / (snr * f ** 2))  # TODO: check square-root / mean / check if we need to add a factor in front ?
        gamma = gamma[(...,) + (None,) * (x.dim() - 1)]
        return physics.prox_l2(x, y, gamma=gamma * self.fact_realign)

    def forward_unet(self, x0, sigma=None, gamma=None, physics=None, t=None, y=None, img_channels=None):
        """Run head, encoder, bottleneck, decoder (additive skips) and tail on ``x0``."""
        if self.cond_type == "base":
            x0 = self.base_conditioning(x0, sigma, gamma)
            emb_sigma = None
        else:
            # only for the non-basic (edm) embedding
            emb_sigma = self.noise_embedding(sigma)
        emb_timestep = self.timestep_embedding(t)
        block_kwargs = dict(
            emb_sigma=emb_sigma, physics=physics, t=emb_timestep, y=y, img_channels=img_channels
        )
        is_g = self.config == 'G'
        use_pool_upscale = self.pool_type in ("next", "next_max")

        x1 = self.m_head(x0)

        # Encoder. Config 'G' expects each block to return (features, embedding).
        if is_g:
            x1_, emb1_ = self.m_down1(x1, **block_kwargs)
        else:
            x1_ = self.m_down1(x1, scale=0, **block_kwargs)
        x2 = self.pool1(x1_)

        if is_g:
            x3_, emb3_ = self.m_down2(x2, **block_kwargs)
        else:
            x3_ = self.m_down2(x2, scale=1, **block_kwargs)
        x3 = self.pool2(x3_)

        if is_g:
            x4_, emb4_ = self.m_down3(x3, **block_kwargs)
        else:
            x4_ = self.m_down3(x3, scale=2, **block_kwargs)
        x4 = self.pool3(x4_)
        # issue: https://github.com/matthieutrs/ram_project/issues/1
        # (reported workaround: .contiguous() on x4)

        if is_g:
            x, _ = self.m_body(x4, **block_kwargs)
        else:
            x = self.m_body(x4, scale=3, **block_kwargs)

        # Decoder with additive skip connections.
        x = self.pool3.upscale(x + x4) if use_pool_upscale else self.up3(x + x4)
        if is_g:
            x, _ = self.m_up3(x, emb_in=emb4_, **block_kwargs)
        else:
            x = self.m_up3(x, scale=2, **block_kwargs)

        x = self.pool2.upscale(x + x3) if use_pool_upscale else self.up2(x + x3)
        if is_g:
            x, _ = self.m_up2(x, emb_in=emb3_, **block_kwargs)
        else:
            x = self.m_up2(x, scale=1, **block_kwargs)

        x = self.pool1.upscale(x + x2) if use_pool_upscale else self.up1(x + x2)
        if is_g:
            x, _ = self.m_up1(x, emb_in=emb1_, **block_kwargs)
        else:
            x = self.m_up1(x, scale=0, **block_kwargs)

        if self.separate_head:
            return self.m_tail(x + x1, img_channels)
        return self.m_tail(x + x1)

    def forward(self, x, sigma=None, gamma=None, physics=None, t=None, y=None):
        r"""
        Run the network on ``x`` with noise level :math:`\sigma`.

        :param torch.Tensor x: input image; ``x.shape[1]`` selects the head/tail.
        :param float, torch.Tensor sigma: noise level. If ``sigma`` is a float, it is used for all
            images in the batch. If it is a tensor, it must be of shape ``(batch_size,)``.
        :param float, torch.Tensor gamma: second conditioning value (same conventions as ``sigma``).
        :param physics: forward operator; required when ``y`` is given or ``emb_physics`` is True.
        :param t: passed through ``timestep_embedding`` (identity).
        :param torch.Tensor y: measurements; when given, ``x`` is realigned first.
        :raises ValueError: if ``in_channels`` is a list that does not contain ``x.shape[1]``.
        """
        img_channels = x.shape[1]
        if self.emb_physics:
            physics = MultiScaleLinearPhysics(physics, x.shape[-3:], device=x.device)
        if self.separate_head and img_channels not in self.in_channels:
            raise ValueError(f"Input image has {img_channels} channels, but the network only have heads for {self.in_channels} channels.")
        if y is not None:
            x = self.realign_input(x, physics, y)
        return self.forward_unet(x, sigma=sigma, gamma=gamma, physics=physics, t=t, y=y, img_channels=img_channels)
def krylov_embeddings_old(y, p, factor, v=None, N=4, feat_size=1, x_init=None, img_channels=3):
    """
    Legacy Krylov embedding.

    Stacks (along dim 1) ``x - v_0`` followed by the iterates
    ``x_k = factor**2 * A^T A x_{k-1} - v_k`` for ``k = 1 .. N-1``, where ``x_0`` is
    ``A^T y``, or the first ``img_channels`` channels of ``x_init`` when given.

    With ``feat_size > 1`` the offset ``v_k`` is the k-th chunk of ``C`` channels of ``v``.
    Otherwise the same ``v`` is used at every step. ``v`` defaults to zeros.

    NOTE(review): with ``feat_size > 1`` and ``v=None`` the default ``v`` only has ``N - 1``
    chunks, so the last step subtracts an empty slice - pass ``v`` explicitly.
    """
    x = p.A_adjoint(y) if x_init is None else x_init[:, :img_channels, ...]
    scale = factor ** 2

    def normal_op(u):
        return p.A_adjoint(p.A(u)) * scale

    if feat_size > 1:
        _, C, _, _ = x.shape
        if v is None:
            v = torch.zeros_like(x).repeat(1, N - 1, 1, 1)
        pieces = [x - v[:, :C, ...]]
        for step in range(1, N):
            x = normal_op(x) - v[:, step * C:(step + 1) * C, ...]
            pieces.append(x)
    else:
        if v is None:
            v = torch.zeros_like(x)
        pieces = [x - v]
        for _ in range(N - 1):
            x = normal_op(x) - v
            pieces.append(x)
    return torch.cat(pieces, dim=1)
def krylov_embeddings(y, p, factor, v=None, N=4, x_init=None, img_channels=3):
    """
    Stack a Krylov-type sequence built from the scaled normal operator of ``p``.

    Returns ``cat([x_0, x_1, ..., x_{N-1}], dim=1)``. Here ``x_0 = A^T y``, or ``x_init``
    when given, and ``x_k = factor**2 * A^T A x_{k-1} - v``.

    Args:
        y (torch.Tensor): measurements.
        p: linear operator exposing ``A`` and ``A_adjoint``.
        factor (float): scale; the normal operator is multiplied by ``factor ** 2``.
        v (torch.Tensor, optional): offset subtracted at every iteration. Defaults to zeros like ``x_0``.
        N (int, optional): number of stacked terms (at least one: ``x_0``). Defaults to 4.
        x_init (torch.Tensor, optional): starting point used as-is (not sliced) instead of ``A^T y``.
        img_channels (int, optional): unused; kept so existing callers keep working.

    Returns:
        torch.Tensor: a new tensor (never aliasing ``x_init``) with the terms concatenated on dim 1.
    """
    x = p.A_adjoint(y) if x_init is None else x_init
    norm = factor ** 2
    if v is None:
        v = torch.zeros_like(x)

    terms = [x]
    x_k = x
    for _ in range(N - 1):
        x_k = p.A_adjoint(p.A(x_k)) * norm - v
        terms.append(x_k)
    # A single concatenation instead of re-copying the growing result every iteration.
    return torch.cat(terms, dim=1)
def grad_embeddings(y, p, factor, v=None, N=4, feat_size=1):
    """
    Stack ``v_0 - A^T y`` followed by ``N - 1`` terms ``factor**2 * A^T A v_k - A^T y`` on dim 1.

    With ``feat_size > 1``, ``v_k`` is the k-th chunk of ``C`` channels of ``v``. Otherwise
    the same ``v`` is used for every term, so all ``N - 1`` trailing terms are identical and
    are computed once. ``v`` defaults to zeros.

    NOTE(review): with ``feat_size > 1`` and ``v=None`` the default ``v`` has only ``N - 1``
    chunks, so the last term uses an empty slice - pass ``v`` explicitly.
    """
    Aty = p.A_adjoint(y)
    norm = factor ** 2

    def AtA(u):
        return p.A_adjoint(p.A(u)) * norm

    if feat_size > 1:
        _, C, _, _ = Aty.shape
        if v is None:
            v = torch.zeros_like(Aty).repeat(1, N - 1, 1, 1)
        terms = [v[:, :C, ...] - Aty]
        for i in range(N - 1):
            terms.append(AtA(v[:, (i + 1) * C:(i + 2) * C, ...]) - Aty)
    else:
        if v is None:
            v = torch.zeros_like(Aty)
        terms = [v - Aty]
        if N > 1:
            # Loop-invariant: v never changes, so evaluate the operator a single time.
            step = AtA(v) - Aty
            terms.extend([step] * (N - 1))
    return torch.cat(terms, dim=1)
def prox_embeddings(y, p, factor, v=None, N=4):
    """
    Approximate ``N - 1`` regularised inversions of ``p`` with conjugate gradient and append ``A^T y``.

    The unknown stacks ``N - 1`` copies of the image along dim 1. Copy ``k`` uses the
    weight ``gamma_k`` from ``logspace(-4, -1, N - 1)``. The CG right-hand side is ``A^T y``
    repeated, and CG runs for at most 3 iterations with ``tol=1e-3``.

    NOTE(review): the operator handed to CG subtracts ``gamma * v``, so it is affine rather
    than linear whenever ``v`` is nonzero - confirm this is intended.
    """
    x = p.A_adjoint(y)
    _, channels, _, _ = x.shape
    if v is None:
        v = torch.zeros_like(x)
    v = v.repeat(1, N - 1, 1, 1)

    weights = torch.logspace(-4, -1, N - 1, device=x.device)
    gamma = weights.repeat_interleave(channels)[None, :, None, None]
    scale = factor ** 2

    def normal_op(u):
        chunks = [
            p.A_adjoint(p.A(u[:, k * channels:(k + 1) * channels, ...])) * scale
            for k in range(N - 1)
        ]
        return torch.cat(chunks, dim=1)

    def cg_op(u):
        return normal_op(u) + (u - v) * gamma

    solution = conjugate_gradient(cg_op, x.repeat(1, N - 1, 1, 1), max_iter=3, tol=1e-3)
    return torch.cat([solution, x], dim=1)
# -------------------------------------------- | |
# Res Block: x + conv(relu(conv(x))) | |
# -------------------------------------------- | |
class MeasCondBlock(nn.Module):
    """
    Measurement-conditioning block (used by ``ResBlock`` as ``PhysicsBlock``).

    Turns ``(x, y, physics)`` into a feature map through ``encoding_conv``. The variant is
    selected by ``config``:

    - ``'A'``: encode the Krylov stack of ``A^T y``.
    - ``'B'``: decode ``x`` to image space and use it as the offset ``v`` of the Krylov stack.
    - ``'C'``: encode the Krylov stack of ``A^T y`` plus one Krylov stack started from each
      decoded image.
    - ``'D'``: same features as ``'C'``, returns the encoder output *without* the final ReLU,
      and owns extra scalar ``gain*`` parameters.
    - ``'E'``, ``'F'``: dispatched by :meth:`forward`, but ``__init__`` creates none of the
      layers/parameters they read. NOTE(review): they need attributes such as
      ``decoding_conv`` / ``fact_prox`` / ``relu_decoding`` attached elsewhere - confirm before use.
    """

    def __init__(
        self,
        out_channels=64,
        img_channels=None,
        decode_upscale=None,
        config='A',
        N=4,
        depth_encoding=1,
        relu_in_encoding=False,
        skip_in_encoding=True,
        c_mult=1,
    ):
        super(MeasCondBlock, self).__init__()
        self.separate_head = isinstance(img_channels, list)
        self.config = config
        assert img_channels is not None, "img_channels should be provided"
        assert decode_upscale is not None, "decode_upscale should be provided"

        if config in ('A', 'B', 'C', 'D'):
            self.N = N
            self.c_mult = c_mult
            self.relu_encoding = nn.ReLU(inplace=False)
            if config != 'A':
                # Created before encoding_conv to keep the original registration order.
                self.decoding_conv = Tails(out_channels, img_channels, depth=1, scale=1, bias=False, c_mult=c_mult)
            if config in ('C', 'D'):
                # C/D concatenate several Krylov stacks, hence the wider encoder input.
                width_kwargs = {"c_mult": c_mult * N, "c_add": N}
            else:
                width_kwargs = {"c_mult": c_mult}
            self.encoding_conv = Heads(
                img_channels,
                out_channels,
                depth=depth_encoding,
                scale=1,
                bias=False,
                relu_in=relu_in_encoding,
                skip_in=skip_in_encoding,
                **width_kwargs,
            )
        if config == 'D':
            self.gain = torch.nn.Parameter(torch.tensor([1.0]), requires_grad=True)
            self.gain_gradx = torch.nn.Parameter(torch.tensor([1e-2]), requires_grad=True)
            self.gain_grady = torch.nn.Parameter(torch.tensor([1e-2]), requires_grad=True)
            self.gain_pinvx = torch.nn.Parameter(torch.tensor([1e-2]), requires_grad=True)
            self.gain_pinvy = torch.nn.Parameter(torch.tensor([1e-2]), requires_grad=True)

    def forward(self, x, y, physics, t, emb_in=None, img_channels=None, scale=1):
        """
        Dispatch to the ``measurement_conditioning_config_<config>`` method.

        ``t`` and ``emb_in`` are accepted for interface compatibility and are not used.

        :raises NotImplementedError: for an unknown ``config``.
        """
        handlers = {
            'A': self.measurement_conditioning_config_A,
            'F': self.measurement_conditioning_config_F,
            'B': self.measurement_conditioning_config_B,
            'C': self.measurement_conditioning_config_C,
            'D': self.measurement_conditioning_config_D,
            'E': self.measurement_conditioning_config_E,
        }
        handler = handlers.get(self.config)
        if handler is None:
            raise NotImplementedError('Config not implemented')
        return handler(x, y, physics, img_channels=img_channels, scale=scale)

    def measurement_conditioning_config_A(self, x, y, physics, img_channels, scale=0):
        """Encode the Krylov stack of ``A^T y`` at ``physics`` scale ``scale``."""
        physics.set_scale(scale)
        factor = 2 ** scale
        meas = krylov_embeddings(y, physics, factor, N=self.N, img_channels=img_channels)
        cond = self.encoding_conv(meas)
        return self.relu_encoding(cond)

    def measurement_conditioning_config_B(self, x, y, physics, img_channels, scale=0):
        """Encode the Krylov stack of ``A^T y`` offset by the decoded features ``x``."""
        physics.set_scale(scale)
        dec = self.decoding_conv(x, img_channels)
        factor = 2 ** scale
        meas = krylov_embeddings(y, physics, factor, v=dec, N=self.N, img_channels=img_channels)
        cond = self.encoding_conv(meas)
        return self.relu_encoding(cond)

    def _stacked_krylov_features(self, x, y, physics, img_channels, scale):
        """Krylov stack of ``A^T y`` followed by one Krylov stack per decoded image (configs C and D)."""
        physics.set_scale(scale)
        dec = self.decoding_conv(x, img_channels)
        factor = 2 ** scale
        stacks = [krylov_embeddings(y, physics, factor, N=self.N, img_channels=img_channels)]
        # At least the first decoded image is always used, even if c_mult < 1.
        for c in range(max(self.c_mult, 1)):
            stacks.append(
                krylov_embeddings(
                    y,
                    physics,
                    factor,
                    N=self.N,
                    x_init=dec[:, img_channels * c:img_channels * (c + 1)],
                    img_channels=img_channels,
                )
            )
        return torch.cat(stacks, dim=1)

    def measurement_conditioning_config_C(self, x, y, physics, img_channels, scale=0):
        """Encode the stacked Krylov features and apply the final ReLU."""
        meas = self._stacked_krylov_features(x, y, physics, img_channels, scale)
        cond = self.encoding_conv(meas)
        return self.relu_encoding(cond)

    def measurement_conditioning_config_D(self, x, y, physics, img_channels, scale=0):
        """Encode the stacked Krylov features; unlike C, return the pre-ReLU encoder output."""
        meas = self._stacked_krylov_features(x, y, physics, img_channels, scale)
        return self.encoding_conv(meas)

    def measurement_conditioning_config_F(self, x, y, physics, img_channels, scale=0):
        """
        Gradient / pseudo-inverse based conditioning.

        ``scale`` is accepted so that :meth:`forward` can dispatch here like the other
        configs; this variant does not change the physics scale.
        """
        dec_large = self.decoding_conv(x, img_channels)
        dec = self.relu_decoding(dec_large)
        Adec = physics.A(dec)
        grad = physics.A_adjoint(self.gain_gradx ** 2 * Adec - self.gain_grady ** 2 * y)  # TODO: check if we need to have L2 (depending on noise nature, can be automated)
        if 'tomography' in physics.__class__.__name__.lower():
            pinv = physics.prox_l2(dec, self.gain_pinvx ** 2 * Adec - self.gain_pinvy ** 2 * y, gamma=1e9)
        else:
            pinv = physics.A_dagger(self.gain_pinvx ** 2 * Adec - self.gain_pinvy ** 2 * y)
        emb = grad - pinv
        im_emb = dec - physics.A_adjoint_A(dec)
        grad_large = emb + im_emb
        emb_grad = self.encoding_conv(grad_large)
        return self.relu_encoding(emb_grad)

    def measurement_conditioning_config_E(self, x, y, physics, img_channels, scale=1):
        """Proximal-step conditioning with a ``gamma`` estimated from the residual ``A^T(A dec - y)``."""
        dec = self.decoding_conv(x, img_channels)
        physics.set_scale(scale)
        f = getattr(physics, "factor", 1.0)
        err = physics.A_adjoint(physics.A(dec) - y)
        snr = dec.reshape(dec.shape[0], -1).abs().mean(dim=1) / (err.reshape(err.shape[0], -1).abs().mean(dim=1) + 1e-4)
        gamma = 1 / (1e-4 + 1 / (snr * f ** 2 + 1))  # TODO: check square-root / mean / check if we need to add a factor in front
        gamma_est = gamma[(...,) + (None,) * (dec.dim() - 1)]
        prox = physics.prox_l2(dec, y, gamma=gamma_est * self.fact_prox)
        emb = self.fact_prox_skip_1 * prox + self.fact_prox_skip_2 * dec
        emb_grad = self.encoding_conv(emb)
        return self.relu_encoding(emb_grad)
class ResBlock(nn.Module):
    """Residual convolution block with optional measurement conditioning.

    For a regular block (``head=False`` and ``tail=False``) ``forward``
    computes::

        h = relu(conv1(x))
        out = x + conv2(h) [+ gain * PhysicsBlock(h, y, physics, t, ...)]

    where the bracketed term is only added when ``emb_physics`` is True.

    Args (main ones):
        in_channels / out_channels: must be equal for a regular block.
        kernel_size, stride, padding, bias, negative_slope: forwarded to
            ``conv`` for both convolutions.
        embedding: if True, also builds ``gain`` and ``emb_linear``
            (``emb_linear`` is not used by ``forward``).
        emb_physics: if True, builds ``gain`` and a ``MeasCondBlock`` stored
            as ``PhysicsBlock``; the remaining measurement-related arguments
            (config, c_mult, img_channels, decode_upscale, N, depth_encoding,
            relu_in_encoding, skip_in_encoding) are forwarded to it.
        head: if True, builds ``self.head = InHead(...)`` instead of the convs.
        tail: if True, no layers are built.

    NOTE(review): for ``head=True`` or ``tail=True`` the attributes ``conv1``,
    ``nl``, ``conv2`` and ``emb_physics`` are not created, so ``forward``
    cannot run on such a block.
    """

    def __init__(
        self,
        in_channels=64,
        out_channels=64,
        kernel_size=3,
        stride=1,
        padding=1,
        bias=True,
        mode="CRC",
        negative_slope=0.2,
        embedding=False,
        emb_channels=None,
        emb_physics=False,
        img_channels=None,
        decode_upscale=None,
        config = 'A',
        head=False,
        tail=False,
        N=4,
        c_mult=1,
        depth_encoding=1,
        relu_in_encoding=False,
        skip_in_encoding=True,
    ):
        super(ResBlock, self).__init__()
        is_plain_block = not head and not tail
        if is_plain_block:
            assert in_channels == out_channels, "Only support in_channels==out_channels."

        self.separate_head = isinstance(img_channels, list)
        self.config = config
        self.is_head = head
        self.is_tail = tail

        if head:
            self.head = InHead(img_channels, out_channels, input_layer=True)

        if not is_plain_block:
            # Head / tail blocks carry no residual convolutions.
            return

        def make_conv(channels_in):
            return conv(
                channels_in,
                out_channels,
                kernel_size,
                stride,
                padding,
                bias,
                "C",
                negative_slope,
            )

        self.conv1 = make_conv(in_channels)
        self.nl = nn.ReLU(inplace=True)
        self.conv2 = make_conv(out_channels)

        if embedding:
            self.gain = torch.nn.Parameter(torch.tensor([1.0]), requires_grad=True)
            self.emb_linear = MPConv(emb_channels, out_channels, kernel=[])

        self.emb_physics = emb_physics
        if emb_physics:
            # Re-assigns `gain` if it was already created for `embedding`.
            self.gain = torch.nn.Parameter(torch.tensor([1.0]), requires_grad=True)
            self.PhysicsBlock = MeasCondBlock(
                out_channels=out_channels,
                config=config,
                c_mult=c_mult,
                img_channels=img_channels,
                decode_upscale=decode_upscale,
                N=N,
                depth_encoding=depth_encoding,
                relu_in_encoding=relu_in_encoding,
                skip_in_encoding=skip_in_encoding,
            )

    def forward(self, x, emb_sigma=None, physics=None, t=None, y=None, emb_in=None, img_channels=None, scale=0):
        """Return ``x + conv2(h)`` plus the gained physics term when enabled.

        ``h = relu(conv1(x))`` is fed both to ``conv2`` and, when
        ``emb_physics`` is True, to ``PhysicsBlock`` together with ``y``,
        ``physics``, ``t``, ``img_channels`` and ``scale``. ``emb_sigma`` and
        ``emb_in`` are accepted but not used here.
        """
        hidden = self.nl(self.conv1(x))
        residual = self.conv2(hidden)  # Should we sum this with below?
        # TODO: add a factor (1+gain) to the emb_meas? that depends on the input snr
        physics_term = (
            self.gain * self.PhysicsBlock(hidden, y, physics, t, img_channels=img_channels, scale=scale)
            if self.emb_physics
            else 0
        )
        return x + residual + physics_term
def calculate_fan_in_and_fan_out(tensor, pytorch_style: bool = True):
    """Return ``(fan_in, fan_out)`` for a weight tensor.

    Adapted from
    https://github.com/megvii-research/basecls/blob/main/basecls/layers/wrapper.py#L77

    Supported layouts:
      * 2-D ``(out, in)`` -- linear weights;
      * 3-D ``(out, in, k)`` -- 1-D convolution weights;
      * 4-D ``(out, in, kh, kw)`` -- 2-D convolution weights;
      * 5-D ``(groups, out, in, kh, kw)`` -- grouped weights, first reduced
        to 4-D according to ``pytorch_style``.

    Args:
        tensor: weight tensor (anything exposing ``shape``; 5-D input also
            needs ``reshape`` and integer indexing).
        pytorch_style: for 5-D input, merge the group axis into the output
            axis (True) or keep only the first group (False).

    Returns:
        ``(fan_in, fan_out)`` with ``fan_in = in * prod(kernel dims)`` and
        ``fan_out = out * prod(kernel dims)``.

    Raises:
        ValueError: if ``tensor`` does not have 2, 3, 4 or 5 dimensions.
    """
    ndim = len(tensor.shape)
    if ndim not in (2, 3, 4, 5):
        raise ValueError(
            "fan_in and fan_out can only be computed for tensor with 2/3/4/5 "
            "dimensions"
        )
    if ndim == 5:
        # `GOIKK` to `OIKK`
        tensor = tensor.reshape(-1, *tensor.shape[2:]) if pytorch_style else tensor[0]
    num_output_fmaps = tensor.shape[0]
    num_input_fmaps = tensor.shape[1]
    # Product of the kernel dimensions; the empty product (2-D weights) is 1.
    receptive_field_size = math.prod(tensor.shape[2:])
    fan_in = num_input_fmaps * receptive_field_size
    fan_out = num_output_fmaps * receptive_field_size
    return fan_in, fan_out
def weights_init_unext(m, gain_conv=1.0, gain_linear=1.0, init_type="ortho"):
    """Initialise conv / linear weights of ``m`` and all its submodules in place.

    Rules, applied to every module yielded by ``m.modules()`` (``m`` included):
      * ``nn.Conv2d`` / ``nn.ConvTranspose2d`` whose ``str()`` contains
        ``'skip'``: weight filled with ones;
      * other ``nn.Conv2d`` / ``nn.ConvTranspose2d``: orthogonal init with
        gain 0.2 if the last kernel dimension is < 4, otherwise normal init
        with mean 0 and std ``sqrt(2 / fan_out)``;
      * ``nn.Linear`` whose ``str()`` does not contain ``'skip'``: normal init
        with std 0.01.

    Biases and all other module types are left unchanged; objects without a
    ``modules`` attribute are ignored. ``gain_conv``, ``gain_linear`` and
    ``init_type`` are currently unused (kept for call-site compatibility).
    """
    if not hasattr(m, "modules"):
        return

    conv_types = (nn.Conv2d, nn.ConvTranspose2d)
    for layer in m.modules():
        if isinstance(layer, conv_types):
            weight = layer.weight
            # NOTE(review): str(layer) is the layer's repr (class name and
            # constructor arguments), not its attribute name in the parent, so
            # for a stock Conv2d this is normally False; a name-based check
            # via named_modules() may have been intended -- confirm.
            if "skip" in str(layer):
                nn.init.ones_(weight.data)
            elif weight.data.shape[-1] < 4:
                nn.init.orthogonal_(weight.data, gain=0.2)
            else:
                _, fan_out = calculate_fan_in_and_fan_out(weight)
                nn.init.normal_(weight, 0, math.sqrt(2 / fan_out))
        elif isinstance(layer, nn.Linear) and "skip" not in str(layer):
            nn.init.normal_(layer.weight.data, std=0.01)
def old2new(old_key):
    """
    Convert an old DRUNet state-dict key to the new UNExt-style key.

    Patterns handled (``<p>`` is the trailing parameter name, ``weight`` or
    ``bias``, and is carried over unchanged):

    1. Downsampling blocks:
       - residual convs:  m_down3.2.res.0.<p> -> m_down3.enc.2.conv1.<p>
       - strided conv:    m_down3.4.<p>       -> pool3.<p>
    2. Upsampling blocks:
       - transposed conv: m_up3.0.<p>         -> up3.<p>
       - residual convs:  m_up3.2.res.0.<p>   -> m_up3.enc.1.conv1.<p>
         (the block index is shifted down by one because index 0 of an old
         ``m_up`` block is the transposed conv)
    3. Body blocks:
       - residual convs:  m_body.0.res.2.<p>  -> m_body.enc.0.conv2.<p>

    In every residual pattern the old inner index 0 maps to ``conv1`` and any
    other index maps to ``conv2``.

    Args:
        old_key (str): The old key from the state dictionary.

    Returns:
        str or None: The new key if a pattern matched, otherwise None.
    """
    # Residual convolutions of the down / up / body blocks.
    match = re.search(
        r"(m_down\d+|m_up\d+|m_body)\.(\d+)\.res\.(\d+)\.(weight|bias)$", old_key
    )
    if match:
        prefix, block_index, conv_index, param = match.groups()
        block_index = int(block_index)
        if prefix.startswith("m_up"):
            # Old m_up blocks start with the transposed conv at index 0.
            block_index -= 1
        conv_name = "conv1" if int(conv_index) == 0 else "conv2"
        return f"{prefix}.enc.{block_index}.{conv_name}.{param}"

    # Strided downsampling conv: old sub-module index 4 of an m_down block.
    match = re.search(r"m_down(\d+)\.4\.(weight|bias)$", old_key)
    if match:
        return f"pool{match.group(1)}.{match.group(2)}"

    # Transposed upsampling conv: old sub-module index 0 of an m_up block.
    match = re.search(r"m_up(\d+)\.0\.(weight|bias)$", old_key)
    if match:
        return f"up{match.group(1)}.{match.group(2)}"

    # If no patterns match, return None
    return None
def update_keyvals_headtail(old_key, old_value, init_value=None, new_key_name='m_head.conv0.weight', conditioning='base'):
    """
    Adapt an old DRUNet head/tail weight to the channel count of a new model.

    The weight is returned under ``new_key_name`` with the shape of
    ``init_value``. It is zero-padded or truncated along the channel axis
    that differs.

    Head keys (``'head'`` in ``old_key``, checked first) are adapted along
    dim 1, the input channels:

    - ``conditioning == 'base'``: input channels ``[:c_in - 2]`` are copied,
      and the old LAST input channel is copied into both of the last two new
      input channels.
    - any other ``conditioning``: the first ``min(c_in, c_in_old)`` input
      channels are copied; any extra new channels stay zero.

    Tail keys (``'tail'`` in ``old_key``) are adapted along dim 0, the output
    channels. When the counts match, the old tensor is returned (detached).
    Otherwise the first ``min(c_out, c_out_old)`` output channels are copied
    into a zero tensor shaped like ``init_value``.

    Args:
        old_key (str): The old key from the state dictionary.
        old_value (torch.Tensor): The old weight tensor.
        init_value (torch.Tensor): The new model's tensor for this key; only
            its shape, dtype and device are used. Required for head/tail keys.
        new_key_name (str): Key under which the adapted tensor is returned.
        conditioning (str): Head adaptation mode (see above).

    Returns:
        dict or None: ``{new_key_name: new_value}``, or None (after printing
        a message) when ``old_key`` contains neither ``'head'`` nor ``'tail'``.
    """
    if 'head' in old_key:
        old = old_value.detach()
        c_in = init_value.shape[1]
        new_value = torch.zeros_like(init_value.detach())
        if conditioning == 'base':
            # Duplicate the old last input channel into the two new trailing
            # channels, then copy the leading channels one-to-one.
            new_value[:, -2:-1, ...] = old[:, -1:, ...]
            new_value[:, -1:, ...] = old[:, -1:, ...]
            new_value[:, :c_in - 2, ...] = old[:, :c_in - 2, ...]
        else:
            n_shared = min(c_in, old.shape[1])
            new_value[:, :n_shared, ...] = old[:, :n_shared, ...]
        return {new_key_name: new_value}

    if 'tail' in old_key:
        old = old_value.detach()
        c_out = init_value.shape[0]
        if c_out == old.shape[0]:
            return {new_key_name: old}
        new_value = torch.zeros_like(init_value.detach())
        n_shared = min(c_out, old.shape[0])
        new_value[:n_shared, ...] = old[:n_shared, ...]
        return {new_key_name: new_value}

    print(f"Key {old_key} does not contain 'head' or 'tail'.")
    return None
# Smoke test: build the default network and run one forward pass on random data.
if __name__ == "__main__":
    model = UNeXt()
    sample = torch.randn(1, 3, 128, 128)
    # Second positional argument 0.1 -- presumably a noise level; confirm
    # against UNeXt.forward.
    output = model(sample, 0.1)
# Case for diagonal physics | |
# IDEA 1: kills signal in the image of A | |
# im_emb = dec - physics.A_adjoint_A(dec) # will be 0 in the case of denoising, but not inpainting # TODO: add gains here too | |
# IDEA 2: compute norm of signal in ker of A | |
# normker = (dec - physics.A_adjoint_A(dec)).norm() / (dec.norm() + 1e-4) | |
# im_emb = normker * physics.A_adjoint(self.gain_diag_x * physics.A(dec) - self.gain_diag_y * y) # will be 0 in the case of denoising, but not inpainting # TODO: add gains here too | |
# IDEA 3: same as above but add the pinv as well | |
# normker = (dec - physics.A_adjoint_A(dec)).norm() / (dec.norm() + 1e-4) | |
# grad_term = physics.A_adjoint(self.gain_diag_x * physics.A(dec) - self.gain_diag_y * y) | |
# # pinv_term = physics.A_dagger(self.gain_diagpinv_x * physics.A(dec) - self.gain_diagpinv_y * y) | |
# if 'tomography' in physics.__class__.__name__.lower(): # or 'pansharp' in physics.__class__.__name__.lower(): | |
# pinv_term = physics.prox_l2(dec, self.gain_diagpinv_x ** 2 * Adec - self.gain_diagpinv_y ** 2 * y, gamma=1e9) | |
# else: | |
# pinv_term = physics.A_dagger(self.gain_diagpinv_x ** 2 * Adec - self.gain_diagpinv_y ** 2 * y) # TODO: do we set this to gain_gradx ? To get 0 during training too?? Better for denoising I guess | |
# im_emb = normker * (grad_term + pinv_term) # will be 0 in the case of denoising, but not inpainting # TODO: add gains here too | |
# # Mix it | |
# if hasattr(physics.noise_model, 'sigma'): | |
# sigma = physics.noise_model.sigma # SNR ? x /= sigma ** 2 | |
# snr = (y.abs().mean()) / (sigma + 1e-4) # SNR equivariant # TODO: add epsilon | |
# snr = snr[(...,) + (None,) * (im_emb.dim() - 1)] | |
# else: | |
# snr = 1e4 | |
# | |
# grad_large = emb + self.gain_diag * (1 + self.gain_noise / snr) * im_emb |