import torch
import torch.nn as nn
import torch.nn.functional as F

import deepinv as dinv
from deepinv.physics import Physics, LinearPhysics, Downsampling
from deepinv.utils import TensorList
from huggingface_hub import hf_hub_download

device = 'cuda' if torch.cuda.is_available() else 'cpu'
Tensor = torch.cuda.FloatTensor if torch.cuda.is_available() else torch.FloatTensor
class RAM(nn.Module):
    r"""
    RAM model.

    This model is a convolutional neural network (CNN) designed for image reconstruction tasks.

    :param in_channels: number of input channels. If a list is provided, the model has a separate head for each channel count.
    :param device: device to which the model should be moved. If ``None``, the model is created on the default device.
    :param pretrained: if ``True``, the model is initialized with pretrained weights.
    """

    def __init__(
        self,
        in_channels=[1, 2, 3],
        device=None,
        pretrained=True,
    ):
        super(RAM, self).__init__()

        nc = [64, 128, 256, 512]  # number of channels in the network
        self.in_channels = in_channels
        self.fact_realign = torch.nn.Parameter(torch.tensor([1.0], device=device))

        self.separate_head = isinstance(in_channels, list)

        # if in_channels is a list, each head receives its image channels plus the two conditioning maps
        if isinstance(in_channels, list):
            in_channels_first = []
            for i in range(len(in_channels)):
                in_channels_first.append(in_channels[i] + 2)

        self.m_head = InHead(in_channels_first, nc[0])

        self.m_down1 = BaseEncBlock(nc[0], nc[0], img_channels=in_channels, decode_upscale=1)
        self.m_down2 = BaseEncBlock(nc[1], nc[1], img_channels=in_channels, decode_upscale=2)
        self.m_down3 = BaseEncBlock(nc[2], nc[2], img_channels=in_channels, decode_upscale=4)
        self.m_body = BaseEncBlock(nc[3], nc[3], img_channels=in_channels, decode_upscale=8)
        self.m_up3 = BaseEncBlock(nc[2], nc[2], img_channels=in_channels, decode_upscale=4)
        self.m_up2 = BaseEncBlock(nc[1], nc[1], img_channels=in_channels, decode_upscale=2)
        self.m_up1 = BaseEncBlock(nc[0], nc[0], img_channels=in_channels, decode_upscale=1)

        self.pool1 = downsample_strideconv(nc[0], nc[1], bias=False, mode="2")
        self.pool2 = downsample_strideconv(nc[1], nc[2], bias=False, mode="2")
        self.pool3 = downsample_strideconv(nc[2], nc[3], bias=False, mode="2")
        self.up3 = upsample_convtranspose(nc[3], nc[2], bias=False, mode="2")
        self.up2 = upsample_convtranspose(nc[2], nc[1], bias=False, mode="2")
        self.up1 = upsample_convtranspose(nc[1], nc[0], bias=False, mode="2")

        self.m_tail = OutTail(nc[0], in_channels)

        # load pretrained weights from the Hugging Face hub
        if pretrained:
            self.load_state_dict(
                torch.load(hf_hub_download(repo_id="mterris/ram", filename="ram.pth.tar"), map_location=device))

        if device is not None:
            self.to(device)
    def constant2map(self, value, x):
        r"""
        Converts a constant value to a map of the same spatial size as the input tensor x.

        :param float value: constant value
        :param torch.Tensor x: input tensor
        """
        if isinstance(value, torch.Tensor):
            if value.ndim > 0:
                value_map = value.view(x.size(0), 1, 1, 1)
                value_map = value_map.expand(-1, 1, x.size(2), x.size(3))
            else:
                value_map = torch.ones(
                    (x.size(0), 1, x.size(2), x.size(3)), device=x.device
                ) * value[None, None, None, None].to(x.device)
        else:
            value_map = (
                torch.ones((x.size(0), 1, x.size(2), x.size(3)), device=x.device)
                * value
            )
        return value_map

    def base_conditioning(self, x, sigma, gamma):
        noise_level_map = self.constant2map(sigma, x)
        gamma_map = self.constant2map(gamma, x)
        return torch.cat((x, noise_level_map, gamma_map), 1)
    def realign_input(self, x, physics, y):
        r"""
        Realign the input x based on the measurements y and the physics model.
        Applies the proximity operator of the L2 norm with respect to the physics model.

        :param torch.Tensor x: input tensor
        :param deepinv.physics.Physics physics: physics model
        :param torch.Tensor y: measurements
        """
        if hasattr(physics, "factor"):
            f = physics.factor
        elif hasattr(physics, "base") and hasattr(physics.base, "factor"):
            f = physics.base.factor
        elif hasattr(physics, "base") and hasattr(physics.base, "base") and hasattr(physics.base.base, "factor"):
            f = physics.base.base.factor
        else:
            f = 1.0

        sigma = 1e-6  # default value
        if hasattr(physics.noise_model, 'sigma'):
            sigma = physics.noise_model.sigma
        if hasattr(physics, 'base') and hasattr(physics.base, 'noise_model') and hasattr(physics.base.noise_model, 'sigma'):
            sigma = physics.base.noise_model.sigma
        if hasattr(physics, 'base') and hasattr(physics.base, 'base') and hasattr(physics.base.base, 'noise_model') \
                and hasattr(physics.base.base.noise_model, 'sigma'):
            sigma = physics.base.base.noise_model.sigma

        if isinstance(y, TensorList):
            num = (y[0].reshape(y[0].shape[0], -1).abs().mean(1))
        else:
            num = (y.reshape(y.shape[0], -1).abs().mean(1))

        snr = num / (sigma + 1e-4)  # SNR equivariant
        gamma = 1 / (1e-4 + 1 / (snr * f ** 2))  # TODO: check square-root / mean / check if we need to add a factor in front?
        gamma = gamma[(...,) + (None,) * (x.dim() - 1)]
        model_input = physics.prox_l2(x, y, gamma=gamma * self.fact_realign)

        return model_input
    def forward_unet(self, x0, sigma=None, gamma=None, physics=None, y=None):
        r"""
        Forward pass of the UNet model.

        :param torch.Tensor x0: init image
        :param float sigma: Gaussian noise level
        :param float gamma: Poisson noise gain
        :param deepinv.physics.Physics physics: physics measurement operator
        :param torch.Tensor y: measurements
        """
        img_channels = x0.shape[1]
        physics = MultiScaleLinearPhysics(physics, x0.shape[-3:], device=x0.device)

        if self.separate_head and img_channels not in self.in_channels:
            raise ValueError(
                f"Input image has {img_channels} channels, but the network only has heads for {self.in_channels} channels.")

        if y is not None:
            x0 = self.realign_input(x0, physics, y)

        x0 = self.base_conditioning(x0, sigma, gamma)

        x1 = self.m_head(x0)
        x1_ = self.m_down1(x1, physics=physics, y=y, img_channels=img_channels, scale=0)
        x2 = self.pool1(x1_)

        x3_ = self.m_down2(x2, physics=physics, y=y, img_channels=img_channels, scale=1)
        x3 = self.pool2(x3_)

        x4_ = self.m_down3(x3, physics=physics, y=y, img_channels=img_channels, scale=2)
        x4 = self.pool3(x4_)

        x = self.m_body(x4, physics=physics, y=y, img_channels=img_channels, scale=3)

        x = self.up3(x + x4)
        x = self.m_up3(x, physics=physics, y=y, img_channels=img_channels, scale=2)

        x = self.up2(x + x3)
        x = self.m_up2(x, physics=physics, y=y, img_channels=img_channels, scale=1)

        x = self.up1(x + x2)
        x = self.m_up1(x, physics=physics, y=y, img_channels=img_channels, scale=0)

        x = self.m_tail(x + x1, img_channels)

        return x
    def forward(self, y=None, physics=None):
        r"""
        Reconstructs a signal estimate from measurements y.

        :param torch.Tensor y: measurements
        :param deepinv.physics.Physics physics: forward operator
        """
        if physics is None:
            physics = dinv.physics.Denoising(noise_model=dinv.physics.GaussianNoise(sigma=0.), device=y.device)

        # pad the adjoint image so that its spatial size is a multiple of 8
        x_temp = physics.A_adjoint(y)
        pad = (-x_temp.size(-2) % 8, -x_temp.size(-1) % 8)
        physics = Pad(physics, pad)

        x_in = physics.A_adjoint(y)

        sigma = physics.noise_model.sigma if hasattr(physics.noise_model, "sigma") else 1e-3
        gamma = physics.noise_model.gain if hasattr(physics.noise_model, "gain") else 1e-3

        out = self.forward_unet(x_in, sigma=sigma, gamma=gamma, physics=physics, y=y)

        out = physics.remove_pad(out)

        return out
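
# --- Hedged usage sketch (illustrative, not part of the model definition) ---
# Minimal example of running RAM on simulated noisy measurements with a simple
# denoising physics from deepinv. The image size, noise level and the use of
# pretrained weights (downloaded from the Hugging Face hub) are assumptions made
# only for this sketch; any deepinv physics exposing A_adjoint should work similarly.
def _demo_ram_reconstruction():
    model = RAM(device=device, pretrained=True)  # set pretrained=False for an offline smoke test
    x = torch.rand(1, 3, 128, 128, device=device)  # hypothetical ground-truth image
    demo_physics = dinv.physics.Denoising(noise_model=dinv.physics.GaussianNoise(sigma=0.1))
    y = demo_physics(x)  # simulated noisy measurements
    with torch.no_grad():
        x_hat = model(y, demo_physics)
    return x_hat  # same shape as x, here (1, 3, 128, 128)
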
### --------------- MODEL ---------------
class BaseEncBlock(nn.Module):
    def __init__(self, in_channels, out_channels, bias=False, nb=4, img_channels=None, decode_upscale=None):
        super(BaseEncBlock, self).__init__()
        self.enc = nn.ModuleList(
            [
                ResBlock(
                    in_channels,
                    out_channels,
                    bias=bias,
                    img_channels=img_channels,
                    decode_upscale=decode_upscale,
                )
                for _ in range(nb)
            ]
        )

    def forward(self, x, physics=None, y=None, img_channels=None, scale=0):
        for i in range(len(self.enc)):
            x = self.enc[i](x, physics=physics, y=y, img_channels=img_channels, scale=scale)
        return x
def krylov_embeddings(y, p, factor, v=None, N=4, x_init=None):
    r"""
    Efficient Krylov subspace embedding computation.

    :param torch.Tensor y: input tensor.
    :param p: an object with A and A_adjoint methods (linear operator).
    :param float factor: scaling factor.
    :param torch.Tensor v: precomputed values to subtract from the Krylov sequence. Defaults to None.
    :param int N: number of Krylov iterations. Defaults to 4.
    :param torch.Tensor x_init: initial guess. Defaults to None.
    """
    if x_init is None:
        x = p.A_adjoint(y)
    else:
        x = x_init.clone()  # extract the first img_channels

    norm = factor ** 2  # precompute the normalization factor
    AtA = lambda u: p.A_adjoint(p.A(u)) * norm  # define the normal operator

    v = v if v is not None else torch.zeros_like(x)
    out = x.clone()

    # compute the Krylov basis [x, AtA(x) - v, AtA(AtA(x) - v) - v, ...]
    x_k = x.clone()
    for i in range(N - 1):
        x_k = AtA(x_k) - v
        out = torch.cat([out, x_k], dim=1)

    return out
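
# --- Hedged sketch of what krylov_embeddings produces (illustrative only) ---
# For a linear operator p with A / A_adjoint, the output stacks the sequence
# [x, AtA(x) - v, AtA(AtA(x) - v) - v, ...] along the channel axis, so N steps on a
# C-channel input yield an N*C-channel tensor. The Downsampling operator and image
# size below are arbitrary assumptions chosen only to make the sketch self-contained.
def _demo_krylov_embeddings():
    demo_physics = Downsampling(img_size=(3, 64, 64), filter="sinc", factor=2, device=device)
    demo_y = demo_physics.A(torch.rand(1, 3, 64, 64, device=device))
    emb = krylov_embeddings(demo_y, demo_physics, factor=1.0, N=4)
    return emb.shape  # expected: torch.Size([1, 12, 64, 64])
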
class MeasCondBlock(nn.Module):
    r"""
    Measurement conditioning block for the RAM model.

    :param out_channels: number of output channels.
    :param img_channels: number of input channels. If a list is provided, the model has a separate head for each channel count.
    :param decode_upscale: upscaling factor for the decoding convolution.
    :param N: number of Krylov iterations.
    :param depth_encoding: depth of the encoding convolution.
    :param c_mult: multiplier for the number of channels.
    """

    def __init__(self, out_channels=64, img_channels=None, decode_upscale=None, N=4, depth_encoding=1, c_mult=1):
        super(MeasCondBlock, self).__init__()

        self.separate_head = isinstance(img_channels, list)

        assert img_channels is not None, "img_channels should be provided"
        assert decode_upscale is not None, "decode_upscale should be provided"

        self.N = N
        self.c_mult = c_mult
        self.relu_encoding = nn.ReLU(inplace=False)
        self.decoding_conv = Tails(out_channels, img_channels, depth=1, scale=1, bias=False, c_mult=self.c_mult)
        self.encoding_conv = Heads(img_channels, out_channels, depth=depth_encoding, scale=1, bias=False,
                                   c_mult=self.c_mult * N, c_add=N, relu_in=False, skip_in=True)

        self.gain = torch.nn.Parameter(torch.tensor([1.0]), requires_grad=True)
        self.gain_gradx = torch.nn.Parameter(torch.tensor([1e-2]), requires_grad=True)
        self.gain_grady = torch.nn.Parameter(torch.tensor([1e-2]), requires_grad=True)
        self.gain_pinvx = torch.nn.Parameter(torch.tensor([1e-2]), requires_grad=True)
        self.gain_pinvy = torch.nn.Parameter(torch.tensor([1e-2]), requires_grad=True)

    def forward(self, x, y, physics, img_channels=None, scale=1):
        physics.set_scale(scale)

        dec = self.decoding_conv(x, img_channels)
        factor = 2 ** scale

        meas_y = krylov_embeddings(y, physics, factor, N=self.N)
        meas_dec = krylov_embeddings(y, physics, factor, N=self.N, x_init=dec[:, :img_channels, ...])
        for c in range(1, self.c_mult):
            meas_cur = krylov_embeddings(y, physics, factor, N=self.N,
                                         x_init=dec[:, img_channels * c:img_channels * (c + 1)])
            meas_dec = torch.cat([meas_dec, meas_cur], dim=1)

        meas = torch.cat([meas_y, meas_dec], dim=1)

        cond = self.encoding_conv(meas)
        emb = self.relu_encoding(cond)

        return emb
class ResBlock(nn.Module):
    r"""
    Convolutional residual block.

    :param in_channels: number of input channels.
    :param out_channels: number of output channels.
    :param kernel_size: size of the convolution kernel.
    :param stride: stride of the convolution.
    :param padding: padding for the convolution.
    :param bias: whether to use bias in the convolution.
    :param img_channels: number of input image channels. If a list is provided, the model has a separate head for each channel count.
    :param decode_upscale: upscaling factor for the decoding convolution.
    :param head: whether this is a head block.
    :param tail: whether this is a tail block.
    :param N: number of Krylov iterations.
    :param c_mult: multiplier for the number of channels.
    :param depth_encoding: depth of the encoding convolution.
    """

    def __init__(
        self,
        in_channels=64,
        out_channels=64,
        kernel_size=3,
        stride=1,
        padding=1,
        bias=True,
        img_channels=None,
        decode_upscale=None,
        head=False,
        tail=False,
        N=2,
        c_mult=2,
        depth_encoding=2,
    ):
        super(ResBlock, self).__init__()

        if not head and not tail:
            assert in_channels == out_channels, "Only support in_channels==out_channels."

        self.separate_head = isinstance(img_channels, list)
        self.is_head = head
        self.is_tail = tail

        if self.is_head:
            self.head = InHead(img_channels, out_channels, input_layer=True)

        if not self.is_head and not self.is_tail:
            self.conv1 = conv(
                in_channels,
                out_channels,
                kernel_size,
                stride,
                padding,
                bias,
                "C",
            )
            self.nl = nn.ReLU(inplace=True)
            self.conv2 = conv(
                out_channels,
                out_channels,
                kernel_size,
                stride,
                padding,
                bias,
                "C",
            )

        self.gain = torch.nn.Parameter(torch.tensor([1.0]), requires_grad=True)
        self.PhysicsBlock = MeasCondBlock(out_channels=out_channels, c_mult=c_mult,
                                          img_channels=img_channels, decode_upscale=decode_upscale,
                                          N=N, depth_encoding=depth_encoding)

    def forward(self, x, physics=None, y=None, img_channels=None, scale=0):
        u = self.conv1(x)
        u = self.nl(u)
        u_2 = self.conv2(u)

        emb_grad = self.PhysicsBlock(u, y, physics, img_channels=img_channels, scale=scale)
        u_1 = self.gain * emb_grad

        return x + u_2 + u_1
class InHead(torch.nn.Module):
    def __init__(self, in_channels_list, out_channels, mode="", bias=False, input_layer=False):
        super(InHead, self).__init__()
        self.in_channels_list = in_channels_list
        self.input_layer = input_layer
        for i, in_channels in enumerate(in_channels_list):
            conv = AffineConv2d(
                in_channels=in_channels,
                out_channels=out_channels,
                bias=bias,
                mode=mode,
                kernel_size=3,
                stride=1,
                padding=1,
                padding_mode="zeros",
            )
            setattr(self, f"conv{i}", conv)

    def forward(self, x):
        in_channels = x.size(1) - 1 if self.input_layer else x.size(1)

        # find the head matching the number of input channels
        i = self.in_channels_list.index(in_channels)
        x = getattr(self, f"conv{i}")(x)

        return x
class OutTail(torch.nn.Module):
    def __init__(self, in_channels, out_channels_list, mode="", bias=False):
        super(OutTail, self).__init__()
        self.in_channels = in_channels
        self.out_channels_list = out_channels_list
        for i, out_channels in enumerate(out_channels_list):
            conv = AffineConv2d(
                in_channels=in_channels,
                out_channels=out_channels,
                bias=bias,
                mode=mode,
                kernel_size=3,
                stride=1,
                padding=1,
                padding_mode="zeros",
            )
            setattr(self, f"conv{i}", conv)

    def forward(self, x, out_channels):
        # find the tail matching the requested number of output channels
        i = self.out_channels_list.index(out_channels)
        x = getattr(self, f"conv{i}")(x)

        return x
class Heads(torch.nn.Module):
    def __init__(self, in_channels_list, out_channels, depth=2, scale=1, bias=True, mode="bilinear", c_mult=1, c_add=0,
                 relu_in=False, skip_in=False):
        super(Heads, self).__init__()
        self.in_channels_list = [c * (c_mult + c_add) for c in in_channels_list]
        self.scale = scale
        self.mode = mode
        for i, in_channels in enumerate(self.in_channels_list):
            setattr(self, f"head{i}",
                    HeadBlock(in_channels, out_channels, depth=depth, bias=bias, relu_in=relu_in, skip_in=skip_in))

        if self.mode == "":
            self.nl = torch.nn.ReLU(inplace=False)
            if self.scale != 1:
                for i, in_channels in enumerate(in_channels_list):
                    setattr(self, f"down{i}",
                            downsample_strideconv(in_channels, in_channels, bias=False, mode=str(self.scale)))

    def forward(self, x):
        in_channels = x.size(1)
        # find the head matching the number of input channels
        i = self.in_channels_list.index(in_channels)

        if self.scale != 1:
            if self.mode == "bilinear":
                x = torch.nn.functional.interpolate(x, scale_factor=1 / self.scale, mode='bilinear',
                                                    align_corners=False)
            else:
                x = getattr(self, f"down{i}")(x)
                x = self.nl(x)

        x = getattr(self, f"head{i}")(x)
        return x
class Tails(torch.nn.Module):
    def __init__(self, in_channels, out_channels_list, depth=2, scale=1, bias=True, mode="bilinear", c_mult=1,
                 relu_in=False, skip_in=False):
        super(Tails, self).__init__()
        self.out_channels_list = out_channels_list
        self.scale = scale
        for i, out_channels in enumerate(out_channels_list):
            setattr(self, f"tail{i}",
                    HeadBlock(in_channels, out_channels * c_mult, depth=depth, bias=bias, relu_in=relu_in,
                              skip_in=skip_in))

        self.mode = mode
        if self.mode == "":
            self.nl = torch.nn.ReLU(inplace=False)
            if self.scale != 1:
                for i, out_channels in enumerate(out_channels_list):
                    setattr(self, f"up{i}",
                            upsample_convtranspose(out_channels * c_mult, out_channels * c_mult, bias=bias,
                                                   mode=str(self.scale)))

    def forward(self, x, out_channels):
        # find the tail matching the requested number of output channels
        i = self.out_channels_list.index(out_channels)
        x = getattr(self, f"tail{i}")(x)

        if self.scale != 1:
            if self.mode == "bilinear":
                x = torch.nn.functional.interpolate(x, scale_factor=self.scale, mode='bilinear', align_corners=False)
            else:
                x = getattr(self, f"up{i}")(x)

        return x
class HeadBlock(torch.nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=3, bias=True, depth=2, relu_in=False, skip_in=False):
        super(HeadBlock, self).__init__()

        padding = kernel_size // 2

        c = out_channels if depth < 2 else in_channels

        self.convin = torch.nn.Conv2d(in_channels, c, kernel_size, padding=padding, bias=bias)
        self.zero_conv_skip = torch.nn.Conv2d(in_channels, c, 1, bias=False)
        self.depth = depth
        self.nl_1 = torch.nn.ReLU(inplace=False)
        self.nl_2 = torch.nn.ReLU(inplace=False)
        self.relu_in = relu_in
        self.skip_in = skip_in

        for i in range(depth - 1):
            if i < depth - 2:
                c_in, c = in_channels, in_channels
            else:
                c_in, c = in_channels, out_channels

            setattr(self, f"conv1{i}", torch.nn.Conv2d(c_in, c_in, kernel_size, padding=padding, bias=bias))
            setattr(self, f"conv2{i}", torch.nn.Conv2d(c_in, c, kernel_size, padding=padding, bias=bias))
            setattr(self, f"skipconv{i}", torch.nn.Conv2d(c_in, c, 1, bias=False))

    def forward(self, x):
        if self.skip_in and self.relu_in:
            x = self.nl_1(self.convin(x)) + self.zero_conv_skip(x)
        elif self.skip_in and not self.relu_in:
            x = self.convin(x) + self.zero_conv_skip(x)
        else:
            x = self.convin(x)

        for i in range(self.depth - 1):
            aux = getattr(self, f"conv1{i}")(x)
            aux = self.nl_2(aux)
            aux_0 = getattr(self, f"conv2{i}")(aux)
            aux_1 = getattr(self, f"skipconv{i}")(x)
            x = aux_0 + aux_1

        return x
# --------------------------------------------------------------------------------------
class AffineConv2d(nn.Conv2d):
    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        mode="affine",
        bias=False,
        stride=1,
        padding=0,
        dilation=1,
        groups=1,
        padding_mode="circular",
        blind=True,
    ):
        if mode == "affine":  # f(a*x + 1) = a*f(x) + 1
            bias = False
        super().__init__(
            in_channels,
            out_channels,
            kernel_size,
            bias=bias,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            padding_mode=padding_mode,
        )
        self.blind = blind
        self.mode = mode

    def affine(self, w):
        """Returns new kernels that encode affine combinations."""
        return (
            w.view(self.out_channels, -1).roll(1, 1).view(w.size())
            - w
            + 1 / w[0, ...].numel()
        )

    def forward(self, x):
        if self.mode != "affine":
            return super().forward(x)
        else:
            kernel = (
                self.affine(self.weight)
                if self.blind
                else torch.cat(
                    (self.affine(self.weight[:, :-1, :, :]), self.weight[:, -1:, :, :]),
                    dim=1,
                )
            )
            padding = tuple(
                elt for elt in reversed(self.padding) for _ in range(2)
            )  # translate the padding arg used by the Conv module into the one used by F.pad
            padding_mode = (
                self.padding_mode if self.padding_mode != "zeros" else "constant"
            )  # translate the padding_mode arg used by the Conv module into the one used by F.pad
            return F.conv2d(
                F.pad(x, padding, mode=padding_mode),
                kernel,
                stride=self.stride,
                dilation=self.dilation,
                groups=self.groups,
            )
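
# --- Hedged check of the affine property (illustrative sketch, not a unit test) ---
# In "affine" mode the reparametrized kernels sum to one and the bias is disabled, so
# the layer should satisfy f(a*x + b) = a*f(x) + b for scalars a, b (up to floating-point
# error). The tensor sizes and tolerance below are arbitrary assumptions.
def _demo_affine_equivariance():
    conv_aff = AffineConv2d(3, 3, kernel_size=3, mode="affine", padding=1)
    x = torch.rand(1, 3, 32, 32)
    a, b = 2.0, 0.5
    lhs = conv_aff(a * x + b)
    rhs = a * conv_aff(x) + b
    return torch.allclose(lhs, rhs, atol=1e-5)
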
""" | |
Functional blocks below | |
Parts of code borrowed from | |
https://github.com/cszn/DPIR/tree/master/models | |
https://github.com/xinntao/BasicSR | |
""" | |
from collections import OrderedDict | |
import torch | |
import torch.nn as nn | |
""" | |
# -------------------------------------------- | |
# Advanced nn.Sequential | |
# https://github.com/xinntao/BasicSR | |
# -------------------------------------------- | |
""" | |
def sequential(*args):
    """Advanced nn.Sequential.

    Args:
        args: nn.Sequential or nn.Module instances.

    Returns:
        nn.Sequential
    """
    if len(args) == 1:
        if isinstance(args[0], OrderedDict):
            raise NotImplementedError("sequential does not support OrderedDict input.")
        return args[0]  # no sequential is needed
    modules = []
    for module in args:
        if isinstance(module, nn.Sequential):
            for submodule in module.children():
                modules.append(submodule)
        elif isinstance(module, nn.Module):
            modules.append(module)
    return nn.Sequential(*modules)
def conv(
    in_channels=64,
    out_channels=64,
    kernel_size=3,
    stride=1,
    padding=1,
    bias=True,
    mode="CBR",
):
    L = []
    for t in mode:
        if t == "C":
            L.append(
                nn.Conv2d(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    kernel_size=kernel_size,
                    stride=stride,
                    padding=padding,
                    bias=bias,
                )
            )
        elif t == "T":
            L.append(
                nn.ConvTranspose2d(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    kernel_size=kernel_size,
                    stride=stride,
                    padding=padding,
                    bias=bias,
                )
            )
        elif t == "R":
            L.append(nn.ReLU(inplace=True))
        else:
            raise NotImplementedError("Undefined type: {}".format(t))
    return sequential(*L)
# --------------------------------------------
# convTranspose (+ relu)
# --------------------------------------------
def upsample_convtranspose(
    in_channels=64,
    out_channels=3,
    padding=0,
    bias=True,
    mode="2R",
):
    assert len(mode) < 4 and mode[0] in [
        "2",
        "3",
        "4",
        "8",
    ], "mode examples: 2, 2R, 2BR, 3, ..., 4BR."
    kernel_size = int(mode[0])
    stride = int(mode[0])
    mode = mode.replace(mode[0], "T")
    up1 = conv(
        in_channels,
        out_channels,
        kernel_size,
        stride,
        padding,
        bias,
        mode,
    )
    return up1
# --------------------------------------------
# strideconv (+ relu)
# --------------------------------------------
def downsample_strideconv(
    in_channels=64,
    out_channels=64,
    padding=0,
    bias=True,
    mode="2R",
):
    assert len(mode) < 4 and mode[0] in [
        "2",
        "3",
        "4",
        "8",
    ], "mode examples: 2, 2R, 2BR, 3, ..., 4BR."
    kernel_size = int(mode[0])
    stride = int(mode[0])
    mode = mode.replace(mode[0], "C")
    down1 = conv(
        in_channels,
        out_channels,
        kernel_size,
        stride,
        padding,
        bias,
        mode,
    )
    return down1
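
# --- Hedged shape sketch for the resampling helpers (illustrative only) ---
# A "2" mode builds a stride-2 (transposed) convolution, so downsample_strideconv
# halves the spatial size and upsample_convtranspose doubles it. The channel counts
# and input size below are arbitrary assumptions.
def _demo_resampling_shapes():
    down = downsample_strideconv(8, 16, bias=False, mode="2")  # stride-2 conv
    up = upsample_convtranspose(16, 8, bias=False, mode="2")   # stride-2 transposed conv
    x = torch.rand(1, 8, 32, 32)
    z = down(x)    # -> (1, 16, 16, 16)
    x_rec = up(z)  # -> (1, 8, 32, 32)
    return z.shape, x_rec.shape
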
class Upsampling(Downsampling):
    def A(self, x, **kwargs):
        return super().A_adjoint(x, **kwargs)

    def A_adjoint(self, y, **kwargs):
        return super().A(y, **kwargs)

    def prox_l2(self, z, y, gamma, **kwargs):
        return super().prox_l2(z, y, gamma, **kwargs)
class MultiScalePhysics(Physics):
    def __init__(self, physics, img_shape, filter="sinc", scales=[2, 4, 8], device='cpu', **kwargs):
        super().__init__(noise_model=physics.noise_model, **kwargs)
        self.base = physics
        self.scales = scales
        self.img_shape = img_shape
        self.Upsamplings = [Upsampling(img_size=img_shape, filter=filter, factor=factor, device=device)
                            for factor in scales]
        self.scale = 0

    def set_scale(self, scale):
        if scale is not None:
            self.scale = scale

    def A(self, x, scale=None, **kwargs):
        self.set_scale(scale)
        if self.scale == 0:
            return self.base.A(x, **kwargs)
        else:
            return self.base.A(self.Upsamplings[self.scale - 1].A(x), **kwargs)

    def downsample(self, x, scale=None):
        self.set_scale(scale)
        if self.scale == 0:
            return x
        else:
            return self.Upsamplings[self.scale - 1].A_adjoint(x)

    def upsample(self, x, scale=None):
        self.set_scale(scale)
        if self.scale == 0:
            return x
        else:
            return self.Upsamplings[self.scale - 1].A(x)

    def update_parameters(self, **kwargs):
        self.base.update_parameters(**kwargs)
class MultiScaleLinearPhysics(MultiScalePhysics, LinearPhysics):
    def __init__(self, physics, img_shape, filter="sinc", scales=[2, 4, 8], **kwargs):
        super().__init__(physics=physics, img_shape=img_shape, filter=filter, scales=scales, **kwargs)

    def A_adjoint(self, y, scale=None, **kwargs):
        self.set_scale(scale)
        y = self.base.A_adjoint(y, **kwargs)
        if self.scale == 0:
            return y
        else:
            return self.Upsamplings[self.scale - 1].A_adjoint(y)
class Pad(LinearPhysics):
    def __init__(self, physics, pad):
        super().__init__(noise_model=physics.noise_model)
        self.base = physics
        self.pad = pad

    def A(self, x):
        return self.base.A(x[..., self.pad[0]:, self.pad[1]:])

    def A_adjoint(self, y):
        y = self.base.A_adjoint(y)
        y = torch.nn.functional.pad(y, (self.pad[1], 0, self.pad[0], 0))
        return y

    def remove_pad(self, x):
        return x[..., self.pad[0]:, self.pad[1]:]

    def update_parameters(self, **kwargs):
        self.base.update_parameters(**kwargs)