# NOTE: non-Python scrape residue ("Spaces:" / "Sleeping" status lines from a
# Hugging Face Space file listing) removed so the file parses.
from pathlib import Path

import torch
from torch.func import vmap
from torch.utils.data import DataLoader

import deepinv as dinv
from deepinv.unfolded import unfolded_builder
from deepinv.utils.phantoms import RandomPhantomDataset, SheppLoganDataset
from deepinv.optim.optim_iterators import CPIteration, fStep, gStep
from deepinv.optim import Prior, DataFidelity
from deepinv.utils import TensorList

from physics.multiscale import MultiScaleLinearPhysics
from models.heads import Heads, Tails, InHead, OutTail, ConvChannels, SNRModule, EquivConvModule, EquivHeads
def get_PDNet_architecture(in_channels=[1, 2, 3], out_channels=[1, 2, 3], n_primal=3, n_dual=3, device='cuda'):
    r"""Build an unrolled PDNet (learned primal-dual) reconstruction model.

    The returned model runs ``max_iter`` unrolled Chambolle-Pock-style
    iterations in which both proximal maps are replaced by small trainable
    CNN blocks (:class:`PDNet_PrimalBlock` / :class:`PDNet_DualBlock`).

    :param list in_channels: per-task image channel counts; the channel lists
        of every primal/dual block are derived from this below.
    :param list out_channels: not read anywhere in this function.
        # NOTE(review): confirm whether it should feed the blocks' output heads.
    :param int n_primal: number of stacked copies of the primal variable.
    :param int n_dual: number of stacked copies of the dual variable.
    :param str device: device the trainable blocks are moved to.
    :return: the unfolded deepinv model, moved to ``device``.

    # NOTE(review): the list defaults are mutable; they are only read here,
    # but tuples would be safer.
    """

    class PDNetIteration(CPIteration):
        r"""Single iteration of learned primal dual.
        We only redefine the fStep and gStep classes.
        The forward method is inherited from the CPIteration class.
        """

        def __init__(self, **kwargs):
            super().__init__(**kwargs)
            self.g_step = gStepPDNet(**kwargs)  # primal update (trainable prior prox)
            self.f_step = fStepPDNet(**kwargs)  # dual update (trainable data-fidelity prox)

        def forward(
            self, X, cur_data_fidelity, cur_prior, cur_params, y, physics, *args, **kwargs
        ):
            r"""
            Single iteration of the Chambolle-Pock algorithm.

            :param dict X: Dictionary containing the current iterate and the estimated cost.
            :param deepinv.optim.DataFidelity cur_data_fidelity: Instance of the DataFidelity class defining the current data_fidelity.
            :param deepinv.optim.Prior cur_prior: Instance of the Prior class defining the current prior.
            :param dict cur_params: dictionary containing the current parameters of the algorithm.
            :param torch.Tensor y: Input data.
            :param deepinv.physics.Physics physics: Instance of the physics modeling the data-fidelity term.
            :return: Dictionary `{"est": (x, ), "cost": F}` containing the updated current iterate and the estimated current cost.
            """
            x_prev, z_prev, u_prev = X["est"]  # x : primal, z : relaxed primal, u : dual
            # Only the channel counts are used below; the other dims are
            # unpacked for readability.
            BS, C_primal, H_primal, W_primal = x_prev.shape
            _, C_dual, H_dual, W_dual = u_prev.shape
            # Number of image channels in each stacked primal copy.
            n_channels = C_primal // n_primal
            # Apply the forward / adjoint operator to each stacked copy
            # independently and re-concatenate along the channel dimension.
            K = lambda x: torch.cat(
                [physics.A(x[:, i * n_channels:(i + 1) * n_channels, :, :]) for i in range(n_primal)], dim=1)
            K_adjoint = lambda x: torch.cat(
                [physics.A_adjoint(x[:, i * n_channels:(i + 1) * n_channels, :, :]) for i in range(n_dual)], dim=1)
            u = self.f_step(u_prev, K(z_prev), cur_data_fidelity, y, physics, n_channels,
                            cur_params)  # dual update (data_fid)
            x = self.g_step(x_prev, K_adjoint(u), cur_prior, n_channels, cur_params)  # primal update (prior)
            # Over-relaxation step; with beta = 0.0 (see params_algo below) z == x.
            z = x + cur_params["beta"] * (x - x_prev)
            F = (
                self.F_fn(x, cur_data_fidelity, cur_prior, cur_params, y, physics)
                if self.has_cost
                else None
            )
            return {"est": (x, z, u), "cost": F}

    class fStepPDNet(fStep):
        r"""
        Dual update of the PDNet algorithm.
        We write it as a proximal operator of the data fidelity term.
        This proximal mapping is to be replaced by a trainable model.
        """

        def __init__(self, **kwargs):
            super().__init__(**kwargs)

        def forward(self, x, w, cur_data_fidelity, y, physics, n_channels, *args):
            r"""
            :param torch.Tensor x: Current first variable :math:`u`.
            :param torch.Tensor w: Current second variable :math:`A z`.
            :param deepinv.optim.data_fidelity cur_data_fidelity: Instance of the DataFidelity class defining the current data fidelity term.
            :param torch.Tensor y: Input data.
            :param deepinv.physics.Physics physics: unused here; kept for interface compatibility.
            :param int n_channels: number of image channels per stacked copy.
            """
            return cur_data_fidelity.prox(x, w, y, n_channels)

    class gStepPDNet(gStep):
        r"""
        Primal update of the PDNet algorithm.
        We write it as a proximal operator of the prior term.
        This proximal mapping is to be replaced by a trainable model.
        """

        def __init__(self, **kwargs):
            super().__init__(**kwargs)

        def forward(self, x, w, cur_prior, n_channels, *args):
            r"""
            :param torch.Tensor x: Current first variable :math:`x`.
            :param torch.Tensor w: Current second variable :math:`A^\top u`.
            :param deepinv.optim.prior cur_prior: Instance of the Prior class defining the current prior.
            :param int n_channels: number of image channels per stacked copy.
            """
            return cur_prior.prox(x, w, n_channels)

    # %%
    # Define the trainable prior and data fidelity terms.
    # ---------------------------------------------------
    # Prior and data-fidelity are respectively defined as subclass of :class:`deepinv.optim.Prior` and :class:`deepinv.optim.DataFidelity`.
    # Their proximal operators are replaced by trainable models.

    class PDNetPrior(Prior):
        def __init__(self, model, *args, **kwargs):
            super().__init__(*args, **kwargs)
            self.model = model

        def prox(self, x, w, n_channels):
            # Feed the model the full primal stack plus the first dual channel group.
            dual_cond = w[:, 0:n_channels, :, :]
            return self.model(x, dual_cond)

    class PDNetDataFid(DataFidelity):
        def __init__(self, model, *args, **kwargs):
            super().__init__(*args, **kwargs)
            self.model = model

        def prox(self, x, w, y, n_channels):
            # Feed the model the full dual stack, the second primal channel group,
            # and y: input channels = n_channels*n_dual + n_channels + n_channels.
            if n_primal > 1:
                primal_cond = w[:, n_channels:(2 * n_channels), :, :]
            else:
                primal_cond = w[:, 0:n_channels, :, :]
            return self.model(x, primal_cond, y)

    # Unrolled optimization algorithm parameters
    max_iter = 10  # number of unrolled iterations, each with its own weights

    # Set up the data fidelity term. Each layer has its own data fidelity module.
    in_channels_dual = [in_channel * n_dual + in_channel + in_channel for in_channel in in_channels]
    out_channels_dual = [in_channel * n_dual for in_channel in in_channels]
    in_channels_primal = [in_channel * n_primal + in_channel for in_channel in in_channels]
    out_channels_primal = [in_channel * n_primal for in_channel in in_channels]
    # PDNet_DualBlock / PDNet_PrimalBlock are module-level classes defined later
    # in this file; they resolve at call time.
    data_fidelity = [
        PDNetDataFid(model=PDNet_DualBlock(in_channels=in_channels_dual, out_channels=out_channels_dual).to(device)) for
        i in range(max_iter)
    ]
    # Set up the trainable prior. Each layer has its own prior module.
    prior = [
        PDNetPrior(model=PDNet_PrimalBlock(in_channels=in_channels_primal, out_channels=out_channels_primal).to(device))
        for i in range(max_iter)]

    # %%
    # Define the model.
    # -------------------------------

    def custom_init(y, physics):
        # Initialize every primal copy with the pseudo-inverse reconstruction and
        # every dual copy with zeros; the relaxed primal starts equal to the primal.
        x0 = physics.A_dagger(y).repeat(1, n_primal, 1, 1)
        u0 = (0 * y).repeat(1, n_dual, 1, 1)
        return {"est": (x0, x0, u0)}

    def custom_output(X):
        # Return the second stacked primal copy when available (first otherwise),
        # mirroring the conditioning choice made in PDNetDataFid.prox.
        x = X["est"][0]
        n_channels = x.shape[1] // n_primal
        if n_primal > 1:
            return X["est"][0][:, n_channels:(2 * n_channels), :, :]
        else:
            return X["est"][0][:, 0:n_channels, :, :]

    # %%
    # Define the unfolded trainable model.
    # -------------------------------------
    # The original paper of the learned primal dual algorithm the authors used the adjoint operator
    # in the primal update. However, the same authors (among others) find in the paper
    #
    # A. Hauptmann, J. Adler, S. Arridge, O. Öktem,
    # Multi-scale learned iterative reconstruction,
    # IEEE Transactions on Computational Imaging 6, 843-856, 2020.
    #
    # that using a filtered gradient can improve both the training speed and reconstruction quality significantly.
    # Following this approach, we use the filtered backprojection instead of the adjoint operator in the primal step.
    # NOTE(review): the comment above is inherited from the deepinv example, but
    # K_adjoint in PDNetIteration.forward uses physics.A_adjoint, not a filtered
    # backprojection — confirm whether the comment still applies.
    model = unfolded_builder(
        iteration=PDNetIteration(),
        params_algo={"beta": 0.0},
        data_fidelity=data_fidelity,
        prior=prior,
        max_iter=max_iter,
        custom_init=custom_init,
        get_output=custom_output,
    )
    return model.to(device)
def init_weights(m):
    """Xavier-initialize ``torch.nn.Linear`` layers in-place.

    Intended to be passed to :meth:`torch.nn.Module.apply`; modules that are
    not ``Linear`` are left untouched (e.g. the Conv2d layers below keep their
    default initialization).

    :param torch.nn.Module m: submodule visited by ``apply``.
    """
    if isinstance(m, torch.nn.Linear):
        # Fixes from review: dropped the doubled ``torch.torch.`` prefix and
        # replaced deprecated ``xavier_uniform`` with the supported in-place
        # ``xavier_uniform_``.
        torch.nn.init.xavier_uniform_(m.weight)
        # Linear(..., bias=False) has ``bias is None``; the original crashed there.
        if m.bias is not None:
            m.bias.data.fill_(0.0)
class PDNet_PrimalBlock(torch.nn.Module):
    r"""
    Primal block for the Primal-Dual unfolding model.

    From https://arxiv.org/abs/1707.06474.

    Primal variables are images of shape (batch_size, in_channels, height, width).
    Each block receives the current primal variable concatenated channel-wise with
    the backprojected dual variable, passes it through a small convolutional
    network, and returns the updated primal variable via a residual connection.

    :param in_channels: number of input channels, or a per-task list (a list
        triggers a separate input head per channel count).
    :param out_channels: number of output channels, or a per-task list.
    :param int depth: number of convolutional layers in the block. Default: 3.
    :param bool bias: whether to use bias in convolutional layers. Default: True.
    :param int nf: number of features in the convolutional layers. Default: 32.
    """

    def __init__(self, in_channels=[1, 2, 3], out_channels=[1, 2, 3], depth=3, bias=True, nf=32):
        super().__init__()
        self.separate_head = isinstance(in_channels, list)
        self.depth = depth

        # Input head maps the concatenated (primal, backprojected dual) stack to nf features.
        self.in_conv = InHead(in_channels, nf, bias=bias)
        self.in_conv.apply(init_weights)

        # depth-2 hidden 3x3 convolutions at constant width nf.
        hidden = [
            torch.nn.Conv2d(nf, nf, kernel_size=3, stride=1, padding=1, bias=bias)
            for _ in range(self.depth - 2)
        ]
        self.conv_list = torch.nn.ModuleList(hidden)
        self.conv_list.apply(init_weights)

        # Output tail maps back to the primal channel count.
        self.out_conv = OutTail(nf, out_channels, bias=bias)
        self.out_conv.apply(init_weights)

        # One PReLU after the head and after each hidden conv.
        self.nl_list = torch.nn.ModuleList([torch.nn.PReLU() for _ in range(self.depth - 1)])

    def forward(self, x, Atu):
        r"""
        Forward pass of the primal block.

        :param torch.Tensor x: current primal variable.
        :param torch.Tensor Atu: backprojected dual variable.
        :return: (:class:`torch.Tensor`) the updated primal variable.
        """
        n_out = x.shape[1]
        features = self.nl_list[0](self.in_conv(torch.cat((x, Atu), dim=1)))
        for conv, act in zip(self.conv_list, self.nl_list[1:]):
            features = act(conv(features))
        # Residual update around the learned correction.
        return self.out_conv(features, n_out) + x
class PDNet_DualBlock(torch.nn.Module):
    r"""
    Dual block for the Primal-Dual unfolding model.

    From https://arxiv.org/abs/1707.06474.

    Dual variables are images of shape (batch_size, in_channels, height, width).
    Each block receives the current dual variable concatenated channel-wise with
    the projected primal variable and the measurements, passes it through a small
    convolutional network, and returns the updated dual variable via a residual
    connection.

    :param in_channels: number of input channels, or a per-task list (a list
        triggers a separate input head per channel count).
    :param out_channels: number of output channels, or a per-task list.
    :param int depth: number of convolutional layers in the block. Default: 3.
    :param bool bias: whether to use bias in convolutional layers. Default: True.
    :param int nf: number of features in the convolutional layers. Default: 32.
    """

    def __init__(self, in_channels=[1, 2, 3], out_channels=[6, 2, 3], depth=3, bias=True, nf=32):
        super().__init__()
        self.depth = depth

        # Input head maps the concatenated (dual, projected primal, y) stack to nf features.
        self.in_conv = InHead(in_channels, nf, bias=bias)
        self.in_conv.apply(init_weights)

        # depth-2 hidden 3x3 convolutions at constant width nf.
        hidden = [
            torch.nn.Conv2d(nf, nf, kernel_size=3, stride=1, padding=1, bias=bias)
            for _ in range(self.depth - 2)
        ]
        self.conv_list = torch.nn.ModuleList(hidden)
        self.conv_list.apply(init_weights)

        # Output tail maps back to the dual channel count.
        self.out_conv = OutTail(nf, out_channels, bias=bias)
        self.out_conv.apply(init_weights)

        # One PReLU after the head and after each hidden conv.
        self.nl_list = torch.nn.ModuleList([torch.nn.PReLU() for _ in range(self.depth - 1)])

    def forward(self, u, Ax_cur, y):
        r"""
        Forward pass of the dual block.

        :param torch.Tensor u: current dual variable.
        :param torch.Tensor Ax_cur: projection of the primal variable.
        :param torch.Tensor y: measurements.
        :return: (:class:`torch.Tensor`) the updated dual variable.
        """
        n_out = u.shape[1]
        features = self.nl_list[0](self.in_conv(torch.cat((u, Ax_cur, y), dim=1)))
        for conv, act in zip(self.conv_list, self.nl_list[1:]):
            features = act(conv(features))
        # Residual update around the learned correction.
        return self.out_conv(features, n_out) + u