denoising / models /ram.py
mterris's picture
update
e0e7789
raw
history blame
30.3 kB
import torch
import torch.nn as nn
import torch.nn.functional as F
import deepinv as dinv
from deepinv.physics import Physics, LinearPhysics, Downsampling
from deepinv.utils import TensorList
from deepinv.utils.tensorlist import TensorList
from huggingface_hub import hf_hub_download
# Module-level defaults chosen from CUDA availability (queried once).
# NOTE(review): neither name is referenced in the visible code, and
# torch.cuda.FloatTensor / torch.FloatTensor are legacy tensor types.
_cuda_available = torch.cuda.is_available()
device = 'cuda' if _cuda_available else 'cpu'
Tensor = torch.cuda.FloatTensor if _cuda_available else torch.FloatTensor
class RAM(nn.Module):
    r"""
    RAM model.

    A 4-level UNet-style convolutional network for image reconstruction. Every
    encoder/decoder stage is a stack of residual blocks whose features are
    conditioned on the measurements through Krylov embeddings of the forward
    operator, applied at the resolution of that stage.

    :param in_channels: Supported number of image channels. If a list is provided,
        the model has a separate input head / output tail for each entry, selected at
        runtime from the channel count of the input. A single ``int`` is treated as a
        one-element list. Defaults to ``[1, 2, 3]``.
    :param device: Device to which the model should be moved. If None, the model is
        created on the default device.
    :param pretrained: If True, load the weights ``ram.pth.tar`` from the Hugging Face
        repository ``mterris/ram``.
    """

    def __init__(
        self,
        in_channels=None,
        device=None,
        pretrained=True,
    ):
        super().__init__()
        if in_channels is None:
            # Fresh list per instance instead of a shared mutable default argument.
            in_channels = [1, 2, 3]
        nc = [64, 128, 256, 512]  # number of feature channels at each scale

        self.in_channels = in_channels
        # Learnable multiplier on the step size used by ``realign_input``.
        self.fact_realign = torch.nn.Parameter(torch.tensor([1.0], device=device))
        self.separate_head = isinstance(in_channels, list)

        # Sub-modules iterate over the supported channel counts, so normalize to a list
        # (a bare int previously left ``in_channels_first`` undefined).
        channels = [in_channels] if isinstance(in_channels, int) else list(in_channels)
        # The head also receives the sigma and gamma conditioning maps (+2 channels).
        in_channels_first = [c + 2 for c in channels]

        self.m_head = InHead(in_channels_first, nc[0])

        self.m_down1 = BaseEncBlock(nc[0], nc[0], img_channels=channels, decode_upscale=1)
        self.m_down2 = BaseEncBlock(nc[1], nc[1], img_channels=channels, decode_upscale=2)
        self.m_down3 = BaseEncBlock(nc[2], nc[2], img_channels=channels, decode_upscale=4)
        self.m_body = BaseEncBlock(nc[3], nc[3], img_channels=channels, decode_upscale=8)
        self.m_up3 = BaseEncBlock(nc[2], nc[2], img_channels=channels, decode_upscale=4)
        self.m_up2 = BaseEncBlock(nc[1], nc[1], img_channels=channels, decode_upscale=2)
        self.m_up1 = BaseEncBlock(nc[0], nc[0], img_channels=channels, decode_upscale=1)

        # Stride-2 convolutions between levels and their transposed counterparts.
        self.pool1 = downsample_strideconv(nc[0], nc[1], bias=False, mode="2")
        self.pool2 = downsample_strideconv(nc[1], nc[2], bias=False, mode="2")
        self.pool3 = downsample_strideconv(nc[2], nc[3], bias=False, mode="2")
        self.up3 = upsample_convtranspose(nc[3], nc[2], bias=False, mode="2")
        self.up2 = upsample_convtranspose(nc[2], nc[1], bias=False, mode="2")
        self.up1 = upsample_convtranspose(nc[1], nc[0], bias=False, mode="2")

        self.m_tail = OutTail(nc[0], channels)

        if pretrained:
            ckpt_path = hf_hub_download(repo_id="mterris/ram", filename="ram.pth.tar")
            # Without an explicit map_location, tensors saved on GPU cannot be loaded on a
            # CPU-only host; load_state_dict copies into the existing parameters anyway.
            map_location = device if device is not None else "cpu"
            # NOTE(review): torch.load unpickles the downloaded file; only use trusted
            # checkpoints (consider weights_only=True on recent PyTorch).
            self.load_state_dict(torch.load(ckpt_path, map_location=map_location))

        if device is not None:
            self.to(device)

    def constant2map(self, value, x):
        r"""
        Broadcast a constant to a single-channel map with the batch and spatial size of ``x``.

        :params value: Python number, 0-d tensor, or tensor holding either one value per
            batch element or a single value shared by the whole batch.
        :params torch.Tensor x: Reference tensor indexed as ``(B, C, H, W)``.
        :returns: Tensor of shape ``(B, 1, H, W)`` on ``x.device``.
        """
        batch, height, width = x.size(0), x.size(2), x.size(3)
        if isinstance(value, torch.Tensor):
            value = value.to(x.device)
            if value.ndim > 0:
                # reshape(-1, ...) accepts B values or a single value broadcast over the batch.
                return value.reshape(-1, 1, 1, 1).expand(batch, 1, height, width)
            return torch.ones((batch, 1, height, width), device=x.device) * value
        return torch.ones((batch, 1, height, width), device=x.device) * value

    def base_conditioning(self, x, sigma, gamma):
        r"""
        Append constant ``sigma`` and ``gamma`` maps to ``x`` along dimension 1.

        :returns: Tensor with two more channels than ``x``.
        """
        noise_level_map = self.constant2map(sigma, x)
        gamma_map = self.constant2map(gamma, x)
        return torch.cat((x, noise_level_map, gamma_map), 1)

    def realign_input(self, x, physics, y):
        r"""
        Realign the input x based on the measurements y and the physics model.

        Returns ``physics.prox_l2(x, y, gamma)``, the proximity operator of the L2 data
        term, with a per-sample ``gamma`` derived from an SNR estimate
        ``mean|y| / sigma`` and the operator's ``factor``, scaled by ``fact_realign``.

        :params torch.Tensor x: Input tensor
        :params deepinv.physics.Physics physics: Physics model (possibly wrapped up to two
            levels deep through ``.base`` attributes)
        :params torch.Tensor y: Measurements
        """
        # Look through up to two levels of wrapping (e.g. multi-scale -> pad -> base).
        levels = [physics]
        if hasattr(physics, "base"):
            levels.append(physics.base)
            if hasattr(physics.base, "base"):
                levels.append(physics.base.base)

        # The outermost ``factor`` wins; 1.0 if none is found.
        f = next((level.factor for level in levels if hasattr(level, "factor")), 1.0)

        # The innermost noise level wins; 1e-6 if none is found.
        sigma = 1e-6
        for level in levels:
            noise_model = getattr(level, "noise_model", None)
            if hasattr(noise_model, "sigma"):
                sigma = noise_model.sigma

        y_ref = y[0] if isinstance(y, TensorList) else y
        num = y_ref.reshape(y_ref.shape[0], -1).abs().mean(1)
        snr = num / (sigma + 1e-4)  # SNR equivariant
        gamma = 1 / (1e-4 + 1 / (snr * f ** 2))  # TODO: check square-root / mean / check if we need to add a factor in front ?
        # Broadcast the per-sample step over the remaining dimensions of x.
        gamma = gamma[(...,) + (None,) * (x.dim() - 1)]
        return physics.prox_l2(x, y, gamma=gamma * self.fact_realign)

    def forward_unet(self, x0, sigma=None, gamma=None, physics=None, y=None):
        r"""
        Forward pass of the UNet model.

        :params torch.Tensor x0: init image; its channel count selects the head/tail
        :params sigma: Gaussian noise level (float or tensor), appended as a constant map
        :params gamma: Poisson noise gain (float or tensor), appended as a constant map
        :params deepinv.physics.Physics physics: physics measurement operator; wrapped in
            ``MultiScaleLinearPhysics`` so each level can apply it at its own resolution
        :params torch.Tensor y: measurements; when given, ``x0`` is realigned first
        :raises ValueError: if the network has no head for the channel count of ``x0``.
        """
        img_channels = x0.shape[1]
        physics = MultiScaleLinearPhysics(physics, x0.shape[-3:], device=x0.device)

        if self.separate_head and img_channels not in self.in_channels:
            raise ValueError(
                f"Input image has {img_channels} channels, but the network only have heads for {self.in_channels} channels.")

        if y is not None:
            x0 = self.realign_input(x0, physics, y)

        x0 = self.base_conditioning(x0, sigma, gamma)

        # Encoder: ``scale`` k means the physics is applied at resolution 1 / 2**k.
        x1 = self.m_head(x0)
        x1_ = self.m_down1(x1, physics=physics, y=y, img_channels=img_channels, scale=0)
        x2 = self.pool1(x1_)
        x3_ = self.m_down2(x2, physics=physics, y=y, img_channels=img_channels, scale=1)
        x3 = self.pool2(x3_)
        x4_ = self.m_down3(x3, physics=physics, y=y, img_channels=img_channels, scale=2)
        x4 = self.pool3(x4_)
        x = self.m_body(x4, physics=physics, y=y, img_channels=img_channels, scale=3)

        # Decoder with additive skip connections.
        x = self.up3(x + x4)
        x = self.m_up3(x, physics=physics, y=y, img_channels=img_channels, scale=2)
        x = self.up2(x + x3)
        x = self.m_up2(x, physics=physics, y=y, img_channels=img_channels, scale=1)
        x = self.up1(x + x2)
        x = self.m_up1(x, physics=physics, y=y, img_channels=img_channels, scale=0)
        return self.m_tail(x + x1, img_channels)

    def forward(self, y=None, physics=None):
        r"""
        Reconstructs a signal estimate from measurements y.

        ``A^T y`` is zero-padded at the top/left so that height and width are multiples of
        8 (three stride-2 poolings), passed through :meth:`forward_unet`, then cropped back.

        :param torch.tensor y: measurements
        :param deepinv.physics.Physics physics: forward operator. Defaults to a
            ``Denoising`` operator with zero-variance Gaussian noise.
        """
        if physics is None:
            physics = dinv.physics.Denoising(noise_model=dinv.physics.GaussianNoise(sigma=0.), device=y.device)

        x_temp = physics.A_adjoint(y)
        pad = (-x_temp.size(-2) % 8, -x_temp.size(-1) % 8)
        physics = Pad(physics, pad)

        x_in = physics.A_adjoint(y)

        sigma = physics.noise_model.sigma if hasattr(physics.noise_model, "sigma") else 1e-3
        gamma = physics.noise_model.gain if hasattr(physics.noise_model, "gain") else 1e-3

        out = self.forward_unet(x_in, sigma=sigma, gamma=gamma, physics=physics, y=y)
        return physics.remove_pad(out)
### --------------- MODEL ---------------
class BaseEncBlock(nn.Module):
    r"""
    Stack of ``nb`` measurement-conditioned residual blocks applied one after another.

    :param in_channels: Feature channels entering each residual block.
    :param out_channels: Feature channels produced by each residual block.
    :param bias: Whether the residual convolutions use a bias.
    :param nb: Number of residual blocks.
    :param img_channels: Supported image channel counts, forwarded to every ``ResBlock``.
    :param decode_upscale: Upscaling factor forwarded to every ``ResBlock``.
    """

    def __init__(self, in_channels, out_channels, bias=False, nb=4, img_channels=None, decode_upscale=None):
        super().__init__()
        blocks = []
        for _ in range(nb):
            blocks.append(
                ResBlock(
                    in_channels,
                    out_channels,
                    bias=bias,
                    img_channels=img_channels,
                    decode_upscale=decode_upscale,
                )
            )
        self.enc = nn.ModuleList(blocks)

    def forward(self, x, physics=None, y=None, img_channels=None, scale=0):
        r"""Feed ``x`` through every residual block, sharing the same conditioning arguments."""
        for block in self.enc:
            x = block(x, physics=physics, y=y, img_channels=img_channels, scale=scale)
        return x
def krylov_embeddings(y, p, factor, v=None, N=4, x_init=None):
    r"""
    Stack the first terms of a (shifted) Krylov sequence along dimension 1.

    Starting from ``x_0`` (``p.A_adjoint(y)``, or ``x_init`` when given), the terms are
    ``x_{k+1} = factor**2 * A^T A x_k - v`` and the result is
    ``torch.cat([x_0, x_1, ..., x_{N-1}], dim=1)``.

    :params torch.Tensor y: Measurements; only used when ``x_init`` is None.
    :params p: Linear operator exposing ``A`` and ``A_adjoint`` methods.
    :params float factor: Scale factor; the normal operator is multiplied by ``factor ** 2``.
    :params torch.Tensor v: Tensor subtracted from every new term. Defaults to zeros.
    :params int N: Number of terms in the stack (``N <= 1`` returns ``x_0`` only). Defaults to 4.
    :params torch.Tensor x_init: Starting point ``x_0``. Defaults to ``p.A_adjoint(y)``.
    :returns: A new tensor with ``max(N, 1)`` times the size of ``x_0`` along dimension 1.
    """
    x = p.A_adjoint(y) if x_init is None else x_init
    norm = factor ** 2

    def normal_op(u):
        # Scaled normal operator factor**2 * A^T A.
        return p.A_adjoint(p.A(u)) * norm

    if v is None:
        v = torch.zeros_like(x)

    terms = [x]
    x_k = x
    for _ in range(N - 1):
        x_k = normal_op(x_k) - v
        terms.append(x_k)
    # One concatenation at the end instead of re-copying the growing result each iteration;
    # torch.cat always allocates, so the result never aliases x_init.
    return torch.cat(terms, dim=1)
class MeasCondBlock(nn.Module):
    r"""
    Measurement conditioning block for the RAM model.

    Decodes ``c_mult`` image-space estimates from the feature map, builds Krylov
    embeddings (see ``krylov_embeddings``) of ``A^T y`` and of each estimate at the
    current scale of ``physics``, and encodes the concatenated stack back to features.

    :param out_channels: Number of feature channels of the input and of the output.
    :param img_channels: Supported image channel counts (one decoding/encoding head per entry).
    :param decode_upscale: Upscaling factor for the decoding convolution. Required, but not
        used inside this block.
    :param N: Number of Krylov iterations.
    :param depth_encoding: Depth of the encoding convolution.
    :param c_mult: Number of image-space estimates decoded from the features.
    """

    def __init__(self, out_channels=64, img_channels=None, decode_upscale=None, N=4, depth_encoding=1, c_mult=1):
        super().__init__()
        assert img_channels is not None, "img_channels should be provided"
        assert decode_upscale is not None, "decode_upscale should be provided"
        self.separate_head = isinstance(img_channels, list)

        self.N = N
        self.c_mult = c_mult
        self.relu_encoding = nn.ReLU(inplace=False)
        self.decoding_conv = Tails(out_channels, img_channels, depth=1, scale=1, bias=False, c_mult=self.c_mult)
        # Input: N terms for A^T y plus N terms for each of the c_mult decoded estimates.
        self.encoding_conv = Heads(img_channels, out_channels, depth=depth_encoding, scale=1, bias=False,
                                   c_mult=self.c_mult * N, c_add=N, relu_in=False, skip_in=True)

        # Not used in ``forward``. They stay registered so state dicts containing them keep
        # loading -- presumably the pretrained checkpoint does; confirm before removing.
        self.gain = torch.nn.Parameter(torch.tensor([1.0]), requires_grad=True)
        self.gain_gradx = torch.nn.Parameter(torch.tensor([1e-2]), requires_grad=True)
        self.gain_grady = torch.nn.Parameter(torch.tensor([1e-2]), requires_grad=True)
        self.gain_pinvx = torch.nn.Parameter(torch.tensor([1e-2]), requires_grad=True)
        self.gain_pinvy = torch.nn.Parameter(torch.tensor([1e-2]), requires_grad=True)

    def forward(self, x, y, physics, img_channels=None, scale=1):
        r"""
        Compute the measurement-conditioned embedding of the features ``x``.

        Note: calls ``physics.set_scale(scale)``, so ``physics`` keeps that scale afterwards.
        """
        physics.set_scale(scale)
        dec = self.decoding_conv(x, img_channels)
        factor = 2 ** scale

        embeddings = [krylov_embeddings(y, physics, factor, N=self.N)]
        for c in range(self.c_mult):
            start = img_channels * c
            embeddings.append(
                krylov_embeddings(y, physics, factor, N=self.N, x_init=dec[:, start:start + img_channels])
            )

        cond = self.encoding_conv(torch.cat(embeddings, dim=1))
        return self.relu_encoding(cond)
class ResBlock(nn.Module):
    r"""
    Convolutional residual block with a measurement-conditioning branch.

    ``forward`` returns ``x + conv2(relu(conv1(x))) + gain * PhysicsBlock(relu(conv1(x)), ...)``.

    :param in_channels: Number of input channels.
    :param out_channels: Number of output channels (must equal ``in_channels`` for a body block).
    :param kernel_size: Size of the convolution kernel.
    :param stride: Stride of the convolution.
    :param padding: Padding for the convolution.
    :param bias: Whether to use bias in the convolution.
    :param img_channels: Supported image channel counts, forwarded to the conditioning block.
    :param decode_upscale: Upscaling factor for the decoding convolution.
    :param head: Whether this is a head block (only an ``InHead`` is built; the residual
        path used by ``forward`` is not).
    :param tail: Whether this is a tail block (the residual path is not built).
    :param N: Number of Krylov iterations.
    :param c_mult: Multiplier for the number of channels.
    :param depth_encoding: Depth of the encoding convolution.
    """

    def __init__(
        self,
        in_channels=64,
        out_channels=64,
        kernel_size=3,
        stride=1,
        padding=1,
        bias=True,
        img_channels=None,
        decode_upscale=None,
        head=False,
        tail=False,
        N=2,
        c_mult=2,
        depth_encoding=2,
    ):
        super().__init__()
        is_body = not head and not tail
        if is_body:
            assert in_channels == out_channels, "Only support in_channels==out_channels."

        self.separate_head = isinstance(img_channels, list)
        self.is_head = head
        self.is_tail = tail

        if head:
            self.head = InHead(img_channels, out_channels, input_layer=True)

        if is_body:
            self.conv1 = conv(in_channels=in_channels, out_channels=out_channels, kernel_size=kernel_size,
                              stride=stride, padding=padding, bias=bias, mode="C")
            self.nl = nn.ReLU(inplace=True)
            self.conv2 = conv(in_channels=out_channels, out_channels=out_channels, kernel_size=kernel_size,
                              stride=stride, padding=padding, bias=bias, mode="C")
            self.gain = torch.nn.Parameter(torch.tensor([1.0]), requires_grad=True)
            self.PhysicsBlock = MeasCondBlock(
                out_channels=out_channels,
                c_mult=c_mult,
                img_channels=img_channels,
                decode_upscale=decode_upscale,
                N=N,
                depth_encoding=depth_encoding,
            )

    def forward(self, x, physics=None, y=None, img_channels=None, scale=0):
        r"""Residual update of ``x`` from a conv branch and a measurement-conditioned branch."""
        hidden = self.nl(self.conv1(x))
        residual = self.conv2(hidden)
        conditioning = self.PhysicsBlock(hidden, y, physics, img_channels=img_channels, scale=scale)
        return x + residual + self.gain * conditioning
class InHead(torch.nn.Module):
    r"""
    Input head holding one 3x3 convolution per supported input channel count.

    :param in_channels_list: Supported input channel counts; ``conv{i}`` maps
        ``in_channels_list[i]`` channels to ``out_channels``.
    :param out_channels: Number of output feature channels.
    :param mode: ``mode`` forwarded to :class:`AffineConv2d`.
    :param bias: Whether the convolutions use a bias.
    :param input_layer: If True, the channel count used to select the convolution is
        ``x.size(1) - 1``.
    """

    def __init__(self, in_channels_list, out_channels, mode="", bias=False, input_layer=False):
        super().__init__()
        self.in_channels_list = in_channels_list
        self.input_layer = input_layer
        for idx, n_in in enumerate(in_channels_list):
            layer = AffineConv2d(
                in_channels=n_in,
                out_channels=out_channels,
                bias=bias,
                mode=mode,
                kernel_size=3,
                stride=1,
                padding=1,
                padding_mode="zeros",
            )
            setattr(self, f"conv{idx}", layer)

    def forward(self, x):
        r"""Apply the convolution matching the channel count of ``x`` (ValueError if unsupported)."""
        n_in = x.size(1)
        if self.input_layer:
            n_in -= 1
        idx = self.in_channels_list.index(n_in)
        return getattr(self, f"conv{idx}")(x)
class OutTail(torch.nn.Module):
    r"""
    Output tail holding one 3x3 convolution per supported output channel count.

    :param in_channels: Number of input feature channels.
    :param out_channels_list: Supported output channel counts; ``conv{i}`` produces
        ``out_channels_list[i]`` channels.
    :param mode: ``mode`` forwarded to :class:`AffineConv2d`.
    :param bias: Whether the convolutions use a bias.
    """

    def __init__(self, in_channels, out_channels_list, mode="", bias=False):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels_list = out_channels_list
        for idx, n_out in enumerate(out_channels_list):
            setattr(
                self,
                f"conv{idx}",
                AffineConv2d(
                    in_channels=in_channels,
                    out_channels=n_out,
                    bias=bias,
                    mode=mode,
                    kernel_size=3,
                    stride=1,
                    padding=1,
                    padding_mode="zeros",
                ),
            )

    def forward(self, x, out_channels):
        r"""Apply the convolution producing ``out_channels`` channels (ValueError if unsupported)."""
        layer = getattr(self, "conv{}".format(self.out_channels_list.index(out_channels)))
        return layer(x)
class Heads(torch.nn.Module):
    r"""
    Encoding heads, one per supported image channel count.

    Head ``i`` expects ``in_channels_list[i] * (c_mult + c_add)`` input channels. It is
    selected at runtime from the channel count of the input, which is optionally
    downsampled by ``scale`` first.

    :param in_channels_list: Supported image channel counts.
    :param out_channels: Output channels of every head.
    :param depth: Depth of each ``HeadBlock``.
    :param scale: Downsampling factor applied before the head (1 = none).
    :param bias: Whether the head convolutions use a bias.
    :param mode: ``"bilinear"`` interpolation, otherwise a strided convolution followed by a ReLU.
    :param c_mult: Channel multiplier of the input.
    :param c_add: Additional channel multiplier of the input.
    :param relu_in: Forwarded to ``HeadBlock``.
    :param skip_in: Forwarded to ``HeadBlock``.
    """

    def __init__(self, in_channels_list, out_channels, depth=2, scale=1, bias=True, mode="bilinear", c_mult=1, c_add=0,
                 relu_in=False, skip_in=False):
        super().__init__()
        # Actual channel counts of the tensors fed to each head.
        self.in_channels_list = [c * (c_mult + c_add) for c in in_channels_list]
        self.scale = scale
        self.mode = mode

        for i, in_channels in enumerate(self.in_channels_list):
            setattr(self, f"head{i}",
                    HeadBlock(in_channels, out_channels, depth=depth, bias=bias, relu_in=relu_in, skip_in=skip_in))

        # Needed whenever the strided-convolution path can run (any non-bilinear mode).
        if self.mode != "bilinear":
            self.nl = torch.nn.ReLU(inplace=False)

        if self.scale != 1:
            # The downsampling runs on the input of head ``i``, so it must use the same
            # (expanded) channel count; the raw counts did not match that input.
            for i, in_channels in enumerate(self.in_channels_list):
                setattr(self, f"down{i}",
                        downsample_strideconv(in_channels, in_channels, bias=False, mode=str(self.scale)))

    def forward(self, x):
        r"""Downsample ``x`` if requested, then apply the head matching its channel count."""
        i = self.in_channels_list.index(x.size(1))
        if self.scale != 1:
            if self.mode == "bilinear":
                x = torch.nn.functional.interpolate(x, scale_factor=1 / self.scale, mode='bilinear',
                                                    align_corners=False)
            else:
                x = self.nl(getattr(self, f"down{i}")(x))
        return getattr(self, f"head{i}")(x)
class Tails(torch.nn.Module):
    r"""
    Decoding tails, one per supported image channel count.

    Tail ``i`` maps ``in_channels`` features to ``out_channels_list[i] * c_mult`` channels,
    optionally followed by upsampling by ``scale``.

    :param in_channels: Number of input feature channels.
    :param out_channels_list: Supported image channel counts.
    :param depth: Depth of each ``HeadBlock``.
    :param scale: Upsampling factor applied after the tail (1 = none).
    :param bias: Whether the convolutions use a bias.
    :param mode: ``"bilinear"`` interpolation, otherwise a transposed convolution.
    :param c_mult: Number of image-space estimates produced per channel count.
    :param relu_in: Forwarded to ``HeadBlock``.
    :param skip_in: Forwarded to ``HeadBlock``.
    """

    def __init__(self, in_channels, out_channels_list, depth=2, scale=1, bias=True, mode="bilinear", c_mult=1,
                 relu_in=False, skip_in=False):
        super().__init__()
        self.out_channels_list = out_channels_list
        self.scale = scale
        for idx, n_out in enumerate(out_channels_list):
            tail = HeadBlock(in_channels, n_out * c_mult, depth=depth, bias=bias, relu_in=relu_in, skip_in=skip_in)
            setattr(self, f"tail{idx}", tail)

        self.mode = mode
        if mode == "":
            # Registered but not used in ``forward``.
            self.nl = torch.nn.ReLU(inplace=False)

        if scale != 1:
            for idx, n_out in enumerate(out_channels_list):
                width = n_out * c_mult
                setattr(self, f"up{idx}", upsample_convtranspose(width, width, bias=bias, mode=str(scale)))

    def forward(self, x, out_channels):
        r"""Apply the tail producing ``out_channels`` channel groups, then upsample if requested."""
        idx = self.out_channels_list.index(out_channels)
        out = getattr(self, f"tail{idx}")(x)
        if self.scale == 1:
            return out
        if self.mode == "bilinear":
            return torch.nn.functional.interpolate(out, scale_factor=self.scale, mode='bilinear', align_corners=False)
        return getattr(self, f"up{idx}")(out)
class HeadBlock(torch.nn.Module):
    r"""
    Input convolution followed by ``depth - 1`` residual two-convolution stages.

    Each stage computes ``conv2(relu(conv1(h))) + skipconv(h)``, where ``skipconv`` is a
    1x1 convolution. Only the last stage changes the channel count to ``out_channels``.
    With ``depth < 2``, the input convolution produces ``out_channels`` directly.

    :param in_channels: Number of input channels.
    :param out_channels: Number of output channels.
    :param kernel_size: Kernel size of the non-skip convolutions (padding ``kernel_size // 2``).
    :param bias: Whether the non-skip convolutions use a bias.
    :param depth: Number of levels (1 = input convolution only).
    :param relu_in: Apply a ReLU to the output of the input convolution.
    :param skip_in: Add a 1x1 convolution of the input to the input convolution's output.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, bias=True, depth=2, relu_in=False, skip_in=False):
        super().__init__()
        padding = kernel_size // 2

        c = out_channels if depth < 2 else in_channels

        self.convin = torch.nn.Conv2d(in_channels, c, kernel_size, padding=padding, bias=bias)
        # Always registered (also when skip_in is False) so the parameter set does not
        # depend on skip_in.
        self.zero_conv_skip = torch.nn.Conv2d(in_channels, c, 1, bias=False)
        self.depth = depth
        self.nl_1 = torch.nn.ReLU(inplace=False)
        self.nl_2 = torch.nn.ReLU(inplace=False)
        self.relu_in = relu_in
        self.skip_in = skip_in

        for i in range(depth - 1):
            c_out = in_channels if i < depth - 2 else out_channels
            setattr(self, f"conv1{i}", torch.nn.Conv2d(in_channels, in_channels, kernel_size, padding=padding, bias=bias))
            setattr(self, f"conv2{i}", torch.nn.Conv2d(in_channels, c_out, kernel_size, padding=padding, bias=bias))
            setattr(self, f"skipconv{i}", torch.nn.Conv2d(in_channels, c_out, 1, bias=False))

    def forward(self, x):
        r"""Apply the input convolution (with optional ReLU / skip), then the residual stages."""
        out = self.convin(x)
        # relu_in is honoured independently of skip_in (it used to be ignored without skip_in).
        if self.relu_in:
            out = self.nl_1(out)
        if self.skip_in:
            out = out + self.zero_conv_skip(x)

        for i in range(self.depth - 1):
            aux = self.nl_2(getattr(self, f"conv1{i}")(out))
            out = getattr(self, f"conv2{i}")(aux) + getattr(self, f"skipconv{i}")(out)
        return out
# --------------------------------------------------------------------------------------
class AffineConv2d(nn.Conv2d):
    r"""
    ``nn.Conv2d`` whose kernels can be constrained to affine combinations.

    With ``mode="affine"`` the bias is disabled and the effective kernel of every output
    channel (see :meth:`affine`) sums to one, so f(a*x + 1) = a*f(x) + 1 away from
    zero-padded borders. With ``blind=False`` the last input channel keeps its
    unconstrained weights. Any other ``mode`` behaves like a plain ``nn.Conv2d``.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        mode="affine",
        bias=False,
        stride=1,
        padding=0,
        dilation=1,
        groups=1,
        padding_mode="circular",
        blind=True,
    ):
        # An affine map f(a*x + 1) = a*f(x) + 1 cannot carry a bias.
        use_bias = False if mode == "affine" else bias
        super().__init__(
            in_channels,
            out_channels,
            kernel_size,
            bias=use_bias,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            padding_mode=padding_mode,
        )
        self.blind = blind
        self.mode = mode

    def affine(self, w):
        """Return kernels whose entries sum to one for every output channel."""
        flat = w.view(self.out_channels, -1)
        shifted = torch.roll(flat, 1, 1).view(w.size())
        return shifted - w + 1 / w[0, ...].numel()

    def forward(self, x):
        if self.mode != "affine":
            return super().forward(x)

        if self.blind:
            kernel = self.affine(self.weight)
        else:
            constrained = self.affine(self.weight[:, :-1, :, :])
            kernel = torch.cat((constrained, self.weight[:, -1:, :, :]), dim=1)

        # F.pad wants (left, right, top, bottom); nn.Conv2d stores padding as (height, width).
        pad = []
        for amount in reversed(self.padding):
            pad.extend((amount, amount))
        # F.pad names zero padding "constant" where nn.Conv2d uses "zeros".
        pad_mode = "constant" if self.padding_mode == "zeros" else self.padding_mode

        padded = F.pad(x, tuple(pad), mode=pad_mode)
        return F.conv2d(padded, kernel, stride=self.stride, dilation=self.dilation, groups=self.groups)
"""
Functional blocks below
Parts of code borrowed from
https://github.com/cszn/DPIR/tree/master/models
https://github.com/xinntao/BasicSR
"""
from collections import OrderedDict
import torch
import torch.nn as nn
"""
# --------------------------------------------
# Advanced nn.Sequential
# https://github.com/xinntao/BasicSR
# --------------------------------------------
"""
def sequential(*args):
    """Flatten modules into a single ``nn.Sequential``.

    A single argument is returned unchanged (no wrapping). Otherwise ``nn.Sequential``
    arguments are unpacked into their children, other ``nn.Module`` arguments are kept
    as-is, and anything else is silently ignored.

    Raises:
        NotImplementedError: if the single argument is an ``OrderedDict``.
    """
    if len(args) == 1:
        only = args[0]
        if isinstance(only, OrderedDict):
            raise NotImplementedError("sequential does not support OrderedDict input.")
        return only

    flat = []
    for item in args:
        if isinstance(item, nn.Sequential):
            flat.extend(item.children())
        elif isinstance(item, nn.Module):
            flat.append(item)
    return nn.Sequential(*flat)
def conv(
    in_channels=64,
    out_channels=64,
    kernel_size=3,
    stride=1,
    padding=1,
    bias=True,
    mode="CBR",
):
    r"""
    Build the layer stack described by ``mode``, one character per layer.

    ``"C"`` adds an ``nn.Conv2d`` and ``"T"`` an ``nn.ConvTranspose2d`` (both with the given
    geometry and bias); ``"R"`` adds an in-place ``nn.ReLU``. Layers are combined with
    ``sequential``, so a one-character ``mode`` yields the bare layer.

    Note: ``"B"`` is not supported, so the default ``mode="CBR"`` raises.

    :raises NotImplementedError: for any character other than ``C``, ``T`` or ``R``.
    """
    layers = []
    for t in mode:
        if t == "C":
            layers.append(
                nn.Conv2d(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    kernel_size=kernel_size,
                    stride=stride,
                    padding=padding,
                    bias=bias,
                )
            )
        elif t == "T":
            layers.append(
                nn.ConvTranspose2d(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    kernel_size=kernel_size,
                    stride=stride,
                    padding=padding,
                    bias=bias,
                )
            )
        elif t == "R":
            layers.append(nn.ReLU(inplace=True))
        else:
            # Include the offending character in the message.
            raise NotImplementedError(f"Undefined type: {t}")
    return sequential(*layers)
# --------------------------------------------
# convTranspose (+ relu)
# --------------------------------------------
def upsample_convtranspose(
    in_channels=64,
    out_channels=3,
    padding=0,
    bias=True,
    mode="2R",
):
    r"""
    Transposed-convolution upsampling layer.

    The leading digit of ``mode`` (2, 3, 4 or 8) is used as both kernel size and stride,
    so with ``padding=0`` the spatial size is multiplied by that factor. That digit is
    replaced by ``"T"`` and the resulting layer codes are passed to ``conv``
    (e.g. ``"2R"`` -> transposed conv + ReLU).
    """
    assert len(mode) < 4 and mode[0] in ("2", "3", "4", "8"), "mode examples: 2, 2R, 2BR, 3, ..., 4BR."
    factor = int(mode[0])
    layer_codes = mode.replace(mode[0], "T")
    return conv(in_channels, out_channels, factor, factor, padding, bias, layer_codes)
def downsample_strideconv(
    in_channels=64,
    out_channels=64,
    padding=0,
    bias=True,
    mode="2R",
):
    r"""
    Strided-convolution downsampling layer.

    The leading digit of ``mode`` (2, 3, 4 or 8) is used as both kernel size and stride,
    so with ``padding=0`` the spatial size is divided by that factor (rounded down). That
    digit is replaced by ``"C"`` and the resulting layer codes are passed to ``conv``
    (e.g. ``"2R"`` -> strided conv + ReLU).
    """
    assert len(mode) < 4 and mode[0] in ("2", "3", "4", "8"), "mode examples: 2, 2R, 2BR, 3, ..., 4BR."
    factor = int(mode[0])
    layer_codes = mode.replace(mode[0], "C")
    return conv(in_channels, out_channels, factor, factor, padding, bias, layer_codes)
class Upsampling(Downsampling):
    r"""
    Upsampling operator defined as the adjoint of ``Downsampling``.

    ``A`` and ``A_adjoint`` are swapped with respect to the parent class, so ``A`` maps a
    low-resolution image to the high-resolution grid.
    """

    def A(self, x, **kwargs):
        return super().A_adjoint(x, **kwargs)

    def A_adjoint(self, y, **kwargs):
        return super().A(y, **kwargs)

    def prox_l2(self, z, y, gamma, **kwargs):
        r"""
        Proximal operator of ``x -> 1/2 ||A x - y||^2`` for the *upsampling* ``A`` above.

        Delegating to ``super().prox_l2`` would resolve to ``Downsampling.prox_l2``, i.e. the
        prox of the parent's downsampling operator. The generic ``LinearPhysics.prox_l2``
        solver is used instead. NOTE(review): this assumes that solver works through
        ``self.A`` / ``self.A_adjoint`` (and so picks up the swap) -- confirm against the
        installed deepinv version.
        """
        return LinearPhysics.prox_l2(self, z, y, gamma, **kwargs)
class MultiScalePhysics(Physics):
    r"""
    Wraps a physics operator so that it can be applied to coarser versions of the image.

    Scale ``0`` is the original resolution. Scale ``s >= 1`` first upsamples the input
    with the ``Upsampling`` operator of factor ``scales[s - 1]`` and then applies
    ``physics``. The current scale is stored on the instance: every method accepting
    ``scale`` updates it, and ``scale=None`` reuses the last value.

    :param physics: Base physics operator (stored as ``self.base``); its noise model is reused.
    :param img_shape: Image shape forwarded to every ``Upsampling`` operator.
    :param filter: Filter name forwarded to every ``Upsampling`` operator.
    :param scales: Upsampling factors for scales 1, 2, ...
    :param device: Device of the ``Upsampling`` operators.
    """

    def __init__(self, physics, img_shape, filter="sinc", scales=[2, 4, 8], device='cpu', **kwargs):
        super().__init__(noise_model=physics.noise_model, **kwargs)
        self.base = physics
        self.scales = scales
        self.img_shape = img_shape
        upsamplers = []
        for f in scales:
            upsamplers.append(Upsampling(img_size=img_shape, filter=filter, factor=f, device=device))
        self.Upsamplings = upsamplers
        self.scale = 0

    def set_scale(self, scale):
        """Set the current scale; ``None`` leaves it unchanged."""
        if scale is None:
            return
        self.scale = scale

    def A(self, x, scale=None, **kwargs):
        self.set_scale(scale)
        if self.scale != 0:
            x = self.Upsamplings[self.scale - 1].A(x)
        return self.base.A(x, **kwargs)

    def downsample(self, x, scale=None):
        self.set_scale(scale)
        return x if self.scale == 0 else self.Upsamplings[self.scale - 1].A_adjoint(x)

    def upsample(self, x, scale=None):
        self.set_scale(scale)
        return x if self.scale == 0 else self.Upsamplings[self.scale - 1].A(x)

    def update_parameters(self, **kwargs):
        self.base.update_parameters(**kwargs)
class MultiScaleLinearPhysics(MultiScalePhysics, LinearPhysics):
    r"""
    Linear variant of ``MultiScalePhysics`` that also provides the adjoint.

    At scale ``s >= 1``, ``A_adjoint`` applies the base adjoint and then the adjoint of
    the scale-``s`` upsampling, i.e. it returns a coarse image.
    """

    def __init__(self, physics, img_shape, filter="sinc", scales=[2, 4, 8], **kwargs):
        super().__init__(physics=physics, img_shape=img_shape, filter=filter, scales=scales, **kwargs)

    def A_adjoint(self, y, scale=None, **kwargs):
        self.set_scale(scale)
        x = self.base.A_adjoint(y, **kwargs)
        if self.scale != 0:
            x = self.Upsamplings[self.scale - 1].A_adjoint(x)
        return x
class Pad(LinearPhysics):
    r"""
    Wraps a physics operator so that it acts on an image zero-padded at the top/left.

    ``A`` crops the first ``pad[0]`` rows and ``pad[1]`` columns before applying the base
    operator. ``A_adjoint`` zero-pads them back, so the pair remains adjoint.

    :param physics: Base physics operator (stored as ``self.base``); its noise model is reused.
    :param pad: ``(pad_h, pad_w)``, the number of rows / columns padded at the top / left.
    """

    def __init__(self, physics, pad):
        super().__init__(noise_model=physics.noise_model)
        self.base = physics
        self.pad = pad

    def A(self, x, **kwargs):
        # Keyword arguments are forwarded so that Pad can sit under wrappers that pass them
        # through (MultiScalePhysics.A calls ``self.base.A(x, **kwargs)``).
        return self.base.A(self.remove_pad(x), **kwargs)

    def A_adjoint(self, y, **kwargs):
        x = self.base.A_adjoint(y, **kwargs)
        # F.pad order for the last two dims: (left, right, top, bottom).
        return torch.nn.functional.pad(x, (self.pad[1], 0, self.pad[0], 0))

    def remove_pad(self, x):
        """Crop the top/left padding from ``x``."""
        return x[..., self.pad[0]:, self.pad[1]:]

    def update_parameters(self, **kwargs):
        self.base.update_parameters(**kwargs)