from collections import OrderedDict

import torch
import torch.nn as nn
import torch.nn.functional as F

import deepinv as dinv
from deepinv.physics import Physics, LinearPhysics, Downsampling
from deepinv.utils import TensorList
from deepinv.utils.tensorlist import TensorList
from huggingface_hub import hf_hub_download

device = 'cuda' if torch.cuda.is_available() else 'cpu'
Tensor = torch.cuda.FloatTensor if torch.cuda.is_available() else torch.FloatTensor


class RAM(nn.Module):
    r"""
    RAM model

    This model is a convolutional neural network (CNN) designed for image reconstruction tasks.

    :param in_channels: Numbers of image channels supported by the network; one input head and
        one output tail is created per entry. An ``int`` or any iterable of ints is accepted and
        converted to a list. Defaults to ``[1, 2, 3]`` when ``None``.
    :param device: Device to which the model should be moved. If None, the model will be created
        on the default device.
    :param pretrained: If True, the model will be initialized with pretrained weights.
    """

    def __init__(
            self,
            in_channels=None,
            device=None,
            pretrained=True,
    ):
        super(RAM, self).__init__()

        # ``None`` sentinel instead of a mutable default list, which would be shared by every
        # instance (the list is stored on the model and on its sub-modules).
        if in_channels is None:
            in_channels = [1, 2, 3]
        elif isinstance(in_channels, int):
            in_channels = [in_channels]
        else:
            in_channels = list(in_channels)

        nc = [64, 128, 256, 512]  # number of channels in the network

        self.in_channels = in_channels
        self.fact_realign = torch.nn.Parameter(torch.tensor([1.0], device=device))

        self.separate_head = isinstance(in_channels, list)

        # The first layer sees the image concatenated with two conditioning maps
        # (see ``base_conditioning``), hence the ``+ 2``.
        in_channels_first = [c + 2 for c in in_channels]

        self.m_head = InHead(in_channels_first, nc[0])

        self.m_down1 = BaseEncBlock(nc[0], nc[0], img_channels=in_channels, decode_upscale=1)
        self.m_down2 = BaseEncBlock(nc[1], nc[1], img_channels=in_channels, decode_upscale=2)
        self.m_down3 = BaseEncBlock(nc[2], nc[2], img_channels=in_channels, decode_upscale=4)
        self.m_body = BaseEncBlock(nc[3], nc[3], img_channels=in_channels, decode_upscale=8)
        self.m_up3 = BaseEncBlock(nc[2], nc[2], img_channels=in_channels, decode_upscale=4)
        self.m_up2 = BaseEncBlock(nc[1], nc[1], img_channels=in_channels, decode_upscale=2)
        self.m_up1 = BaseEncBlock(nc[0], nc[0], img_channels=in_channels, decode_upscale=1)

        self.pool1 = downsample_strideconv(nc[0], nc[1], bias=False, mode="2")
        self.pool2 = downsample_strideconv(nc[1], nc[2], bias=False, mode="2")
        self.pool3 = downsample_strideconv(nc[2], nc[3], bias=False, mode="2")
        self.up3 = upsample_convtranspose(nc[3], nc[2], bias=False, mode="2")
        self.up2 = upsample_convtranspose(nc[2], nc[1], bias=False, mode="2")
        self.up1 = upsample_convtranspose(nc[1], nc[0], bias=False, mode="2")

        self.m_tail = OutTail(nc[0], in_channels)

        # load pretrained weights from hugging face
        if pretrained:
            self.load_state_dict(
                torch.load(hf_hub_download(repo_id="mterris/ram", filename="ram.pth.tar"), map_location=device))

        if device is not None:
            self.to(device)

    def constant2map(self, value, x):
        r"""
        Converts a constant value to a map of the same size as the input tensor x.

        The returned map has shape ``(x.size(0), 1, x.size(2), x.size(3))``.

        :param value: constant value; a Python number, a 0-dim tensor, or a tensor holding either
            one element or one element per batch item.
        :param torch.Tensor x: input tensor (indexed on dimensions 0, 2 and 3).
        :return: the constant map.
        """
        map_shape = (x.size(0), 1, x.size(2), x.size(3))
        if isinstance(value, torch.Tensor):
            if value.ndim > 0:
                # ``reshape(-1, 1, 1, 1)`` + explicit ``expand`` broadcasts a one-element tensor
                # over the batch as well as handling the one-value-per-item case; any other
                # element count still raises.
                value_map = value.to(x.device).reshape(-1, 1, 1, 1).expand(*map_shape)
            else:
                value_map = torch.ones(map_shape, device=x.device) * value[None, None, None, None].to(x.device)
        else:
            value_map = torch.ones(map_shape, device=x.device) * value
        return value_map

    def base_conditioning(self, x, sigma, gamma):
        r"""
        Concatenates ``x`` with constant maps built from ``sigma`` and ``gamma`` along dim 1.

        :param torch.Tensor x: input tensor
        :param sigma: value used to build the first extra channel
        :param gamma: value used to build the second extra channel
        :return: tensor with two more channels than ``x``.
        """
        noise_level_map = self.constant2map(sigma, x)
        gamma_map = self.constant2map(gamma, x)
        return torch.cat((x, noise_level_map, gamma_map), 1)

    def realign_input(self, x, physics, y):
        r"""
        Realign the input x based on the measurements y and the physics model.
        Applies the proximity operator of the L2 norm with respect to the physics model.

        :param torch.Tensor x: Input tensor
        :param deepinv.physics.Physics physics: Physics model
        :param torch.Tensor y: Measurements
        :return: output of ``physics.prox_l2``.
        """
        # ``factor`` is looked up on the physics itself, then on up to two nested ``.base``
        # wrappers; the first one found wins, 1.0 otherwise.
        if hasattr(physics, "factor"):
            f = physics.factor
        elif hasattr(physics, "base") and hasattr(physics.base, "factor"):
            f = physics.base.factor
        elif hasattr(physics, "base") and hasattr(physics.base, "base") and hasattr(physics.base.base, "factor"):
            f = physics.base.base.factor
        else:
            f = 1.0

        # ``sigma`` is looked up the same way, but the innermost match wins because each
        # successful check overwrites the previous value.
        sigma = 1e-6  # default value
        if hasattr(physics.noise_model, 'sigma'):
            sigma = physics.noise_model.sigma
        if hasattr(physics, 'base') and hasattr(physics.base, 'noise_model') and hasattr(physics.base.noise_model,
                                                                                         'sigma'):
            sigma = physics.base.noise_model.sigma
        if hasattr(physics, 'base') and hasattr(physics.base, 'base') and hasattr(physics.base.base,
                                                                                  'noise_model') and hasattr(
                physics.base.base.noise_model, 'sigma'):
            sigma = physics.base.base.noise_model.sigma

        # Per-sample mean absolute value of the measurements (first entry for a TensorList).
        if isinstance(y, TensorList):
            num = (y[0].reshape(y[0].shape[0], -1).abs().mean(1))
        else:
            num = (y.reshape(y.shape[0], -1).abs().mean(1))

        snr = num / (sigma + 1e-4)  # SNR equivariant
        gamma = 1 / (1e-4 + 1 / (
                snr * f ** 2))  # TODO: check square-root / mean / check if we need to add a factor in front ?
        # Append singleton dimensions so that gamma broadcasts against x.
        gamma = gamma[(...,) + (None,) * (x.dim() - 1)]
        model_input = physics.prox_l2(x, y, gamma=gamma * self.fact_realign)
        return model_input

    def forward_unet(self, x0, sigma=None, gamma=None, physics=None, y=None):
        r"""
        Forward pass of the UNet model.

        :param torch.Tensor x0: init image
        :param float sigma: Gaussian noise level
        :param float gamma: Poisson noise gain
        :param deepinv.physics.Physics physics: physics measurement operator
        :param torch.Tensor y: measurements
        :raises ValueError: if the network has no head for the number of channels of ``x0``.
        """
        img_channels = x0.shape[1]
        physics = MultiScaleLinearPhysics(physics, x0.shape[-3:], device=x0.device)

        if self.separate_head and img_channels not in self.in_channels:
            raise ValueError(
                f"Input image has {img_channels} channels, but the network only have heads for {self.in_channels} channels.")

        if y is not None:
            x0 = self.realign_input(x0, physics, y)

        x0 = self.base_conditioning(x0, sigma, gamma)

        # Encoder
        x1 = self.m_head(x0)
        x1_ = self.m_down1(x1, physics=physics, y=y, img_channels=img_channels, scale=0)
        x2 = self.pool1(x1_)
        x3_ = self.m_down2(x2, physics=physics, y=y, img_channels=img_channels, scale=1)
        x3 = self.pool2(x3_)
        x4_ = self.m_down3(x3, physics=physics, y=y, img_channels=img_channels, scale=2)
        x4 = self.pool3(x4_)

        # Bottleneck
        x = self.m_body(x4, physics=physics, y=y, img_channels=img_channels, scale=3)

        # Decoder with additive skip connections
        x = self.up3(x + x4)
        x = self.m_up3(x, physics=physics, y=y, img_channels=img_channels, scale=2)
        x = self.up2(x + x3)
        x = self.m_up2(x, physics=physics, y=y, img_channels=img_channels, scale=1)
        x = self.up1(x + x2)
        x = self.m_up1(x, physics=physics, y=y, img_channels=img_channels, scale=0)
        x = self.m_tail(x + x1, img_channels)

        return x

    def forward(self, y=None, physics=None):
        r"""
        Reconstructs a signal estimate from measurements y

        :param torch.tensor y: measurements
        :param deepinv.physics.Physics physics: forward operator. When ``None``, a
            ``Denoising`` physics with zero-sigma Gaussian noise is used.
        :return: the reconstruction with the internal padding removed.
        """
        if physics is None:
            physics = dinv.physics.Denoising(noise_model=dinv.physics.GaussianNoise(sigma=0.), device=y.device)

        # Pad the image domain so that both spatial sizes are multiples of 8.
        x_temp = physics.A_adjoint(y)
        pad = (-x_temp.size(-2) % 8, -x_temp.size(-1) % 8)
        physics = Pad(physics, pad)

        x_in = physics.A_adjoint(y)

        sigma = physics.noise_model.sigma if hasattr(physics.noise_model, "sigma") else 1e-3
        gamma = physics.noise_model.gain if hasattr(physics.noise_model, "gain") else 1e-3

        out = self.forward_unet(x_in, sigma=sigma, gamma=gamma, physics=physics, y=y)

        out = physics.remove_pad(out)

        return out


### --------------- MODEL ---------------
class BaseEncBlock(nn.Module):
    r"""
    Stack of ``nb`` :class:`ResBlock` modules applied one after the other.

    :param in_channels: Number of input channels of each residual block.
    :param out_channels: Number of output channels of each residual block.
    :param bias: Whether the convolutions use a bias.
    :param nb: Number of residual blocks.
    :param img_channels: Image channel counts forwarded to each :class:`ResBlock`.
    :param decode_upscale: Forwarded to each :class:`ResBlock`.
    """

    def __init__(self, in_channels, out_channels, bias=False, nb=4, img_channels=None, decode_upscale=None):
        super(BaseEncBlock, self).__init__()
        self.enc = nn.ModuleList(
            [
                ResBlock(
                    in_channels,
                    out_channels,
                    bias=bias,
                    img_channels=img_channels,
                    decode_upscale=decode_upscale,
                )
                for _ in range(nb)
            ]
        )

    def forward(self, x, physics=None, y=None, img_channels=None, scale=0):
        """Apply every residual block in sequence with the same conditioning arguments."""
        for block in self.enc:
            x = block(x, physics=physics, y=y, img_channels=img_channels, scale=scale)
        return x


def krylov_embeddings(y, p, factor, v=None, N=4, x_init=None):
    r"""
    Krylov subspace embedding computation.

    Starting from ``x`` (``p.A_adjoint(y)`` or a copy of ``x_init``), the sequence
    ``x_{k+1} = factor**2 * A^T A x_k - v`` is computed and the ``N`` iterates are
    concatenated along dim 1.

    :param torch.Tensor y: Input tensor.
    :param p: An object with A and A_adjoint methods (linear operator).
    :param float factor: Scaling factor.
    :param torch.Tensor v: Precomputed values to subtract from Krylov sequence. Defaults to None.
    :param int N: Number of Krylov iterations. Defaults to 4.
    :param torch.Tensor x_init: Initial guess. Defaults to None.
    :return: tensor with ``N`` times as many channels as the starting point.
    """
    if x_init is None:
        x = p.A_adjoint(y)
    else:
        x = x_init.clone()

    norm = factor ** 2  # Precompute normalization factor
    AtA = lambda u: p.A_adjoint(p.A(u)) * norm  # Define the linear operator

    v = v if v is not None else torch.zeros_like(x)

    out = x.clone()
    # Compute Krylov basis
    x_k = x.clone()
    for _ in range(N - 1):
        x_k = AtA(x_k) - v
        out = torch.cat([out, x_k], dim=1)

    return out


class MeasCondBlock(nn.Module):
    r"""
    Measurement conditioning block for the RAM model.

    :param out_channels: Number of output channels.
    :param img_channels: Number of input channels. If a list is provided, the model will have
        separate heads for each channel.
    :param decode_upscale: Upscaling factor for the decoding convolution.
    :param N: Number of Krylov iterations.
    :param depth_encoding: Depth of the encoding convolution.
    :param c_mult: Multiplier for the number of channels.
    """

    def __init__(self, out_channels=64, img_channels=None, decode_upscale=None, N=4, depth_encoding=1, c_mult=1):
        super(MeasCondBlock, self).__init__()
        self.separate_head = isinstance(img_channels, list)

        assert img_channels is not None, "img_channels should be provided"
        assert decode_upscale is not None, "decode_upscale should be provided"

        self.N = N
        self.c_mult = c_mult
        self.relu_encoding = nn.ReLU(inplace=False)
        self.decoding_conv = Tails(out_channels, img_channels, depth=1, scale=1, bias=False, c_mult=self.c_mult)
        # The encoder receives N Krylov iterates of the measurements (c_add) plus N iterates of
        # each of the c_mult decoded images (c_mult * N).
        self.encoding_conv = Heads(img_channels, out_channels, depth=depth_encoding, scale=1, bias=False,
                                   c_mult=self.c_mult * N, c_add=N, relu_in=False, skip_in=True)

        # NOTE: only ``gain`` and the ``gain_*`` parameters below are registered here; none of
        # them is read in ``forward``. They are kept so the parameter set is unchanged.
        self.gain = torch.nn.Parameter(torch.tensor([1.0]), requires_grad=True)
        self.gain_gradx = torch.nn.Parameter(torch.tensor([1e-2]), requires_grad=True)
        self.gain_grady = torch.nn.Parameter(torch.tensor([1e-2]), requires_grad=True)
        self.gain_pinvx = torch.nn.Parameter(torch.tensor([1e-2]), requires_grad=True)
        self.gain_pinvy = torch.nn.Parameter(torch.tensor([1e-2]), requires_grad=True)

    def forward(self, x, y, physics, img_channels=None, scale=1):
        r"""
        Compute the measurement-conditioned embedding of the features ``x``.

        :param torch.Tensor x: feature tensor.
        :param y: measurements.
        :param physics: multiscale physics exposing ``set_scale``, ``A`` and ``A_adjoint``.
        :param int img_channels: number of image channels of the current input.
        :param int scale: scale index set on ``physics``; the Krylov factor is ``2 ** scale``.
        :return: ReLU of the encoded Krylov embeddings.
        """
        physics.set_scale(scale)
        dec = self.decoding_conv(x, img_channels)
        factor = 2 ** (scale)

        meas_y = krylov_embeddings(y, physics, factor, N=self.N)
        meas_dec = krylov_embeddings(y, physics, factor, N=self.N, x_init=dec[:, :img_channels, ...])

        # One Krylov embedding per additional group of ``img_channels`` decoded channels.
        for c in range(1, self.c_mult):
            meas_cur = krylov_embeddings(y, physics, factor, N=self.N,
                                         x_init=dec[:, img_channels * c:img_channels * (c + 1)])
            meas_dec = torch.cat([meas_dec, meas_cur], dim=1)

        meas = torch.cat([meas_y, meas_dec], dim=1)
        cond = self.encoding_conv(meas)
        emb = self.relu_encoding(cond)
        return emb


class ResBlock(nn.Module):
    r"""
    Convolutional residual block.

    :param in_channels: Number of input channels.
    :param out_channels: Number of output channels.
    :param kernel_size: Size of the convolution kernel.
    :param stride: Stride of the convolution.
    :param padding: Padding for the convolution.
    :param bias: Whether to use bias in the convolution.
    :param img_channels: Number of input channels. If a list is provided, the model will have
        separate heads for each channel.
    :param decode_upscale: Upscaling factor for the decoding convolution.
    :param head: Whether this is a head block.
    :param tail: Whether this is a tail block.
    :param N: Number of Krylov iterations.
    :param c_mult: Multiplier for the number of channels.
    :param depth_encoding: Depth of the encoding convolution.
    """

    def __init__(
            self,
            in_channels=64,
            out_channels=64,
            kernel_size=3,
            stride=1,
            padding=1,
            bias=True,
            img_channels=None,
            decode_upscale=None,
            head=False,
            tail=False,
            N=2,
            c_mult=2,
            depth_encoding=2,
    ):
        super(ResBlock, self).__init__()

        if not head and not tail:
            assert in_channels == out_channels, "Only support in_channels==out_channels."

        self.separate_head = isinstance(img_channels, list)
        self.is_head = head
        self.is_tail = tail

        if self.is_head:
            self.head = InHead(img_channels, out_channels, input_layer=True)

        if not self.is_head and not self.is_tail:
            self.conv1 = conv(
                in_channels,
                out_channels,
                kernel_size,
                stride,
                padding,
                bias,
                "C",
            )
            self.nl = nn.ReLU(inplace=True)
            self.conv2 = conv(
                out_channels,
                out_channels,
                kernel_size,
                stride,
                padding,
                bias,
                "C",
            )

        self.gain = torch.nn.Parameter(torch.tensor([1.0]), requires_grad=True)
        self.PhysicsBlock = MeasCondBlock(out_channels=out_channels, c_mult=c_mult,
                                          img_channels=img_channels, decode_upscale=decode_upscale,
                                          N=N, depth_encoding=depth_encoding)

    def forward(self, x, physics=None, y=None, img_channels=None, scale=0):
        """Return ``x + conv2(relu(conv1(x))) + gain * PhysicsBlock(relu(conv1(x)))``."""
        u = self.conv1(x)
        u = self.nl(u)  # in-place ReLU: both branches below see the activated features
        u_2 = self.conv2(u)
        emb_grad = self.PhysicsBlock(u, y, physics, img_channels=img_channels, scale=scale)
        u_1 = self.gain * emb_grad
        return x + u_2 + u_1


class InHead(torch.nn.Module):
    r"""
    Set of input convolutions, one per supported number of input channels.

    :param in_channels_list: Supported input channel counts; layer ``i`` is stored as ``conv{i}``.
    :param out_channels: Number of output channels of every head.
    :param mode: Mode forwarded to :class:`AffineConv2d`.
    :param bias: Whether the convolutions use a bias.
    :param input_layer: If True, one channel is subtracted from the input channel count before
        looking up the head in ``forward``.
    """

    def __init__(self, in_channels_list, out_channels, mode="", bias=False, input_layer=False):
        super(InHead, self).__init__()
        self.in_channels_list = in_channels_list
        self.input_layer = input_layer
        for i, in_channels in enumerate(in_channels_list):
            layer = AffineConv2d(
                in_channels=in_channels,
                out_channels=out_channels,
                bias=bias,
                mode=mode,
                kernel_size=3,
                stride=1,
                padding=1,
                padding_mode="zeros",
            )
            setattr(self, f"conv{i}", layer)

    def forward(self, x):
        """Apply the head matching the channel count of ``x`` (``ValueError`` if there is none)."""
        in_channels = x.size(1) - 1 if self.input_layer else x.size(1)

        # find index
        i = self.in_channels_list.index(in_channels)
        x = getattr(self, f"conv{i}")(x)
        return x


class OutTail(torch.nn.Module):
    r"""
    Set of output convolutions, one per supported number of output channels.

    :param in_channels: Number of input channels of every tail.
    :param out_channels_list: Supported output channel counts; layer ``i`` is stored as ``conv{i}``.
    :param mode: Mode forwarded to :class:`AffineConv2d`.
    :param bias: Whether the convolutions use a bias.
    """

    def __init__(self, in_channels, out_channels_list, mode="", bias=False):
        super(OutTail, self).__init__()
        self.in_channels = in_channels
        self.out_channels_list = out_channels_list
        for i, out_channels in enumerate(out_channels_list):
            layer = AffineConv2d(
                in_channels=in_channels,
                out_channels=out_channels,
                bias=bias,
                mode=mode,
                kernel_size=3,
                stride=1,
                padding=1,
                padding_mode="zeros",
            )
            setattr(self, f"conv{i}", layer)

    def forward(self, x, out_channels):
        """Apply the tail producing ``out_channels`` channels (``ValueError`` if there is none)."""
        i = self.out_channels_list.index(out_channels)
        x = getattr(self, f"conv{i}")(x)
        return x


class Heads(torch.nn.Module):
    r"""
    Set of :class:`HeadBlock` encoders, one per supported input channel count.

    The channel count used to select a head is ``c * (c_mult + c_add)`` for each ``c`` in
    ``in_channels_list``.

    :param in_channels_list: Base channel counts.
    :param out_channels: Number of output channels of every head.
    :param depth: Depth forwarded to :class:`HeadBlock`.
    :param scale: Downscaling factor applied before the head (1 disables it).
    :param bias: Whether the convolutions use a bias.
    :param mode: ``"bilinear"`` for interpolation, ``""`` for a learned strided convolution.
    :param c_mult: Channel multiplier.
    :param c_add: Additive term of the channel multiplier.
    :param relu_in: Forwarded to :class:`HeadBlock`.
    :param skip_in: Forwarded to :class:`HeadBlock`.
    """

    def __init__(self, in_channels_list, out_channels, depth=2, scale=1, bias=True, mode="bilinear", c_mult=1,
                 c_add=0, relu_in=False, skip_in=False):
        super(Heads, self).__init__()
        self.in_channels_list = [c * (c_mult + c_add) for c in in_channels_list]
        self.scale = scale
        self.mode = mode
        for i, in_channels in enumerate(self.in_channels_list):
            setattr(self, f"head{i}",
                    HeadBlock(in_channels, out_channels, depth=depth, bias=bias, relu_in=relu_in, skip_in=skip_in))

        if self.mode == "":
            self.nl = torch.nn.ReLU(inplace=False)
            if self.scale != 1:
                # NOTE(review): these convs are sized with the raw channel counts, not the
                # multiplied ones used for the lookup in ``forward`` - confirm this is intended.
                for i, in_channels in enumerate(in_channels_list):
                    setattr(self, f"down{i}",
                            downsample_strideconv(in_channels, in_channels, bias=False, mode=str(self.scale)))

    def forward(self, x):
        """Optionally downscale ``x`` and apply the head matching its channel count."""
        in_channels = x.size(1)
        i = self.in_channels_list.index(in_channels)

        if self.scale != 1:
            if self.mode == "bilinear":
                x = torch.nn.functional.interpolate(x, scale_factor=1 / self.scale, mode='bilinear',
                                                    align_corners=False)
            else:
                x = getattr(self, f"down{i}")(x)
                x = self.nl(x)

        # find index
        x = getattr(self, f"head{i}")(x)
        return x


class Tails(torch.nn.Module):
    r"""
    Set of :class:`HeadBlock` decoders, one per supported output channel count.

    Tail ``i`` outputs ``out_channels_list[i] * c_mult`` channels.

    :param in_channels: Number of input channels of every tail.
    :param out_channels_list: Supported output channel counts.
    :param depth: Depth forwarded to :class:`HeadBlock`.
    :param scale: Upscaling factor applied after the tail (1 disables it).
    :param bias: Whether the convolutions use a bias.
    :param mode: ``"bilinear"`` for interpolation, ``""`` for a learned transposed convolution.
    :param c_mult: Channel multiplier.
    :param relu_in: Forwarded to :class:`HeadBlock`.
    :param skip_in: Forwarded to :class:`HeadBlock`.
    """

    def __init__(self, in_channels, out_channels_list, depth=2, scale=1, bias=True, mode="bilinear", c_mult=1,
                 relu_in=False, skip_in=False):
        super(Tails, self).__init__()
        self.out_channels_list = out_channels_list
        self.scale = scale
        for i, out_channels in enumerate(out_channels_list):
            setattr(self, f"tail{i}",
                    HeadBlock(in_channels, out_channels * c_mult, depth=depth, bias=bias, relu_in=relu_in,
                              skip_in=skip_in))

        self.mode = mode
        if self.mode == "":
            self.nl = torch.nn.ReLU(inplace=False)
            if self.scale != 1:
                for i, out_channels in enumerate(out_channels_list):
                    setattr(self, f"up{i}",
                            upsample_convtranspose(out_channels * c_mult, out_channels * c_mult, bias=bias,
                                                   mode=str(self.scale)))

    def forward(self, x, out_channels):
        """Apply the tail registered for ``out_channels`` and optionally upscale the result."""
        i = self.out_channels_list.index(out_channels)
        x = getattr(self, f"tail{i}")(x)

        # find index
        if self.scale != 1:
            if self.mode == "bilinear":
                x = torch.nn.functional.interpolate(x, scale_factor=self.scale, mode='bilinear',
                                                    align_corners=False)
            else:
                x = getattr(self, f"up{i}")(x)
        return x


class HeadBlock(torch.nn.Module):
    r"""
    Input convolution followed by ``depth - 1`` small residual units.

    Each unit computes ``conv2(relu(conv1(x))) + skipconv(x)`` where ``skipconv`` is a 1x1
    convolution. Only the last unit changes the channel count to ``out_channels``.

    :param in_channels: Number of input channels.
    :param out_channels: Number of output channels.
    :param kernel_size: Kernel size of the (non-skip) convolutions.
    :param bias: Whether the (non-skip) convolutions use a bias.
    :param depth: Total depth; with ``depth < 2`` the input convolution maps directly to
        ``out_channels``.
    :param relu_in: Apply a ReLU after the input convolution (only used with ``skip_in``).
    :param skip_in: Add a 1x1-convolution skip around the input convolution.
    """

    def __init__(self, in_channels, out_channels, kernel_size=3, bias=True, depth=2, relu_in=False, skip_in=False):
        super(HeadBlock, self).__init__()

        padding = kernel_size // 2

        c = out_channels if depth < 2 else in_channels

        self.convin = torch.nn.Conv2d(in_channels, c, kernel_size, padding=padding, bias=bias)
        self.zero_conv_skip = torch.nn.Conv2d(in_channels, c, 1, bias=False)
        self.depth = depth
        self.nl_1 = torch.nn.ReLU(inplace=False)
        self.nl_2 = torch.nn.ReLU(inplace=False)
        self.relu_in = relu_in
        self.skip_in = skip_in

        for i in range(depth - 1):
            if i < depth - 2:
                c_in, c = in_channels, in_channels
            else:
                c_in, c = in_channels, out_channels

            setattr(self, f"conv1{i}", torch.nn.Conv2d(c_in, c_in, kernel_size, padding=padding, bias=bias))
            setattr(self, f"conv2{i}", torch.nn.Conv2d(c_in, c, kernel_size, padding=padding, bias=bias))
            setattr(self, f"skipconv{i}", torch.nn.Conv2d(c_in, c, 1, bias=False))

    def forward(self, x):
        """Apply the input convolution (with optional skip / ReLU) then the residual units."""
        if self.skip_in and self.relu_in:
            x = self.nl_1(self.convin(x)) + self.zero_conv_skip(x)
        elif self.skip_in and not self.relu_in:
            x = self.convin(x) + self.zero_conv_skip(x)
        else:
            x = self.convin(x)

        for i in range(self.depth - 1):
            aux = getattr(self, f"conv1{i}")(x)
            aux = self.nl_2(aux)
            aux_0 = getattr(self, f"conv2{i}")(aux)
            aux_1 = getattr(self, f"skipconv{i}")(x)
            x = aux_0 + aux_1

        return x


# --------------------------------------------------------------------------------------
class AffineConv2d(nn.Conv2d):
    r"""
    2D convolution with an optional "affine" mode.

    With ``mode != "affine"`` this behaves exactly like :class:`torch.nn.Conv2d`. With
    ``mode == "affine"`` the bias is disabled and the kernels are re-parameterised by
    :meth:`affine` before being applied.

    :param in_channels: Number of input channels.
    :param out_channels: Number of output channels.
    :param kernel_size: Kernel size.
    :param mode: ``"affine"`` to enable the affine re-parameterisation.
    :param bias: Whether to use a bias (forced to False in affine mode).
    :param stride: Convolution stride.
    :param padding: Convolution padding.
    :param dilation: Convolution dilation.
    :param groups: Convolution groups.
    :param padding_mode: Padding mode, as for :class:`torch.nn.Conv2d`.
    :param blind: If False, the last input channel is excluded from the affine
        re-parameterisation.
    """

    def __init__(
            self,
            in_channels,
            out_channels,
            kernel_size,
            mode="affine",
            bias=False,
            stride=1,
            padding=0,
            dilation=1,
            groups=1,
            padding_mode="circular",
            blind=True,
    ):
        if mode == "affine":  # f(a*x + 1) = a*f(x) + 1
            bias = False
        super().__init__(
            in_channels,
            out_channels,
            kernel_size,
            bias=bias,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            padding_mode=padding_mode,
        )
        self.blind = blind
        self.mode = mode

    def affine(self, w):
        """returns new kernels that encode affine combinations"""
        return (
                w.view(self.out_channels, -1).roll(1, 1).view(w.size())
                - w
                + 1 / w[0, ...].numel()
        )

    def forward(self, x):
        """Apply the convolution, using the re-parameterised kernels in affine mode."""
        if self.mode != "affine":
            return super().forward(x)
        else:
            kernel = (
                self.affine(self.weight)
                if self.blind
                else torch.cat(
                    (self.affine(self.weight[:, :-1, :, :]), self.weight[:, -1:, :, :]),
                    dim=1,
                )
            )
            padding = tuple(
                elt for elt in reversed(self.padding) for _ in range(2)
            )  # used to translate padding arg used by Conv module to the ones used by F.pad
            padding_mode = (
                self.padding_mode if self.padding_mode != "zeros" else "constant"
            )  # used to translate padding_mode arg used by Conv module to the ones used by F.pad
            return F.conv2d(
                F.pad(x, padding, mode=padding_mode),
                kernel,
                stride=self.stride,
                dilation=self.dilation,
                groups=self.groups,
            )


# Functional blocks below
# Parts of code borrowed from
# https://github.com/cszn/DPIR/tree/master/models
# https://github.com/xinntao/BasicSR

# --------------------------------------------
# Advanced nn.Sequential
# https://github.com/xinntao/BasicSR
# --------------------------------------------


def sequential(*args):
    """Advanced nn.Sequential.

    A single module is returned unchanged; nested ``nn.Sequential`` arguments are flattened;
    arguments that are not ``nn.Module`` instances are ignored.

    Args:
        nn.Sequential, nn.Module
    Returns:
        nn.Sequential
    """
    if len(args) == 1:
        if isinstance(args[0], OrderedDict):
            raise NotImplementedError("sequential does not support OrderedDict input.")
        return args[0]  # No sequential is needed.
    modules = []
    for module in args:
        if isinstance(module, nn.Sequential):
            modules.extend(module.children())
        elif isinstance(module, nn.Module):
            modules.append(module)
    return nn.Sequential(*modules)


def conv(
        in_channels=64,
        out_channels=64,
        kernel_size=3,
        stride=1,
        padding=1,
        bias=True,
        mode="CBR",
):
    r"""
    Build a layer stack from a mode string, one layer per character.

    Supported characters are ``"C"`` (``nn.Conv2d``), ``"T"`` (``nn.ConvTranspose2d``) and
    ``"R"`` (in-place ``nn.ReLU``). Any other character - including the ``"B"`` of the default
    ``"CBR"`` - raises ``NotImplementedError``.

    :param in_channels: Input channels of the convolutions.
    :param out_channels: Output channels of the convolutions.
    :param kernel_size: Kernel size of the convolutions.
    :param stride: Stride of the convolutions.
    :param padding: Padding of the convolutions.
    :param bias: Whether the convolutions use a bias.
    :param mode: Layer specification string.
    :return: the single layer, or an ``nn.Sequential`` when several layers are requested.
    :raises NotImplementedError: on an unsupported mode character.
    """
    L = []
    for t in mode:
        if t == "C":
            L.append(
                nn.Conv2d(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    kernel_size=kernel_size,
                    stride=stride,
                    padding=padding,
                    bias=bias,
                )
            )
        elif t == "T":
            L.append(
                nn.ConvTranspose2d(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    kernel_size=kernel_size,
                    stride=stride,
                    padding=padding,
                    bias=bias,
                )
            )
        elif t == "R":
            L.append(nn.ReLU(inplace=True))
        else:
            # The original message had no placeholder, so the offending character was dropped.
            raise NotImplementedError(f"Undefined type: {t}")
    return sequential(*L)


# --------------------------------------------
# convTranspose (+ relu)
# --------------------------------------------
def upsample_convtranspose(
        in_channels=64,
        out_channels=3,
        padding=0,
        bias=True,
        mode="2R",
):
    r"""
    Transposed-convolution upsampling block.

    The first character of ``mode`` is the upscaling factor (used as both kernel size and
    stride); the remaining characters are layer codes understood by :func:`conv`.

    :param in_channels: Number of input channels.
    :param out_channels: Number of output channels.
    :param padding: Padding of the transposed convolution.
    :param bias: Whether the transposed convolution uses a bias.
    :param mode: Factor digit followed by optional layer codes, e.g. ``"2"`` or ``"2R"``.
    """
    assert len(mode) < 4 and mode[0] in [
        "2",
        "3",
        "4",
        "8",
    ], "mode examples: 2, 2R, 2BR, 3, ..., 4BR."
    kernel_size = int(mode[0])
    stride = int(mode[0])
    mode = mode.replace(mode[0], "T")
    up1 = conv(
        in_channels,
        out_channels,
        kernel_size,
        stride,
        padding,
        bias,
        mode,
    )
    return up1


def downsample_strideconv(
        in_channels=64,
        out_channels=64,
        padding=0,
        bias=True,
        mode="2R",
):
    r"""
    Strided-convolution downsampling block.

    The first character of ``mode`` is the downscaling factor (used as both kernel size and
    stride); the remaining characters are layer codes understood by :func:`conv`.

    :param in_channels: Number of input channels.
    :param out_channels: Number of output channels.
    :param padding: Padding of the convolution.
    :param bias: Whether the convolution uses a bias.
    :param mode: Factor digit followed by optional layer codes, e.g. ``"2"`` or ``"2R"``.
    """
    assert len(mode) < 4 and mode[0] in [
        "2",
        "3",
        "4",
        "8",
    ], "mode examples: 2, 2R, 2BR, 3, ..., 4BR."
    kernel_size = int(mode[0])
    stride = int(mode[0])
    mode = mode.replace(mode[0], "C")
    down1 = conv(
        in_channels,
        out_channels,
        kernel_size,
        stride,
        padding,
        bias,
        mode,
    )
    return down1


class Upsampling(Downsampling):
    """``Downsampling`` operator with its forward and adjoint maps swapped."""

    def A(self, x, **kwargs):
        """Forward map: the adjoint of the parent ``Downsampling``."""
        return super().A_adjoint(x, **kwargs)

    def A_adjoint(self, y, **kwargs):
        """Adjoint map: the forward map of the parent ``Downsampling``."""
        return super().A(y, **kwargs)

    def prox_l2(self, z, y, gamma, **kwargs):
        """Delegate to the parent implementation."""
        return super().prox_l2(z, y, gamma, **kwargs)


class MultiScalePhysics(Physics):
    r"""
    Wraps a physics so that it can be applied to coarser versions of the image.

    At scale 0 the base physics is used directly. At scale ``s > 0`` the input is first passed
    through the ``(s - 1)``-th :class:`Upsampling` operator and then through the base physics.

    :param physics: Base physics; its ``noise_model`` is reused.
    :param img_shape: Image shape passed as ``img_size`` to the :class:`Upsampling` operators.
    :param filter: Filter passed to the :class:`Upsampling` operators.
    :param scales: Resampling factors, one operator per entry. Defaults to ``[2, 4, 8]``.
    :param device: Device passed to the :class:`Upsampling` operators.
    """

    def __init__(self, physics, img_shape, filter="sinc", scales=None, device='cpu', **kwargs):
        super().__init__(noise_model=physics.noise_model, **kwargs)
        # ``None`` sentinel instead of a mutable default list shared between instances.
        if scales is None:
            scales = [2, 4, 8]
        self.base = physics
        self.scales = scales
        self.img_shape = img_shape
        # Kept as a plain Python list (not an ``nn.ModuleList``) so the operators are not
        # registered as sub-modules and the module's state dict is unchanged.
        self.Upsamplings = [Upsampling(img_size=img_shape, filter=filter, factor=factor, device=device) for factor in
                            scales]
        self.scale = 0

    def set_scale(self, scale):
        """Set the current scale index; ``None`` leaves it unchanged."""
        if scale is not None:
            self.scale = scale

    def A(self, x, scale=None, **kwargs):
        """Apply the base forward operator, upsampling ``x`` first when the scale is not 0."""
        self.set_scale(scale)
        if self.scale == 0:
            return self.base.A(x, **kwargs)
        else:
            return self.base.A(self.Upsamplings[self.scale - 1].A(x), **kwargs)

    def downsample(self, x, scale=None):
        """Map ``x`` to the current scale (identity at scale 0)."""
        self.set_scale(scale)
        if self.scale == 0:
            return x
        else:
            return self.Upsamplings[self.scale - 1].A_adjoint(x)

    def upsample(self, x, scale=None):
        """Map ``x`` from the current scale back to full resolution (identity at scale 0)."""
        self.set_scale(scale)
        if self.scale == 0:
            return x
        else:
            return self.Upsamplings[self.scale - 1].A(x)

    def update_parameters(self, **kwargs):
        """Forward parameter updates to the base physics."""
        self.base.update_parameters(**kwargs)


class MultiScaleLinearPhysics(MultiScalePhysics, LinearPhysics):
    r"""
    Linear version of :class:`MultiScalePhysics`, adding the adjoint operator.

    :param physics: Base linear physics.
    :param img_shape: Image shape passed to the :class:`Upsampling` operators.
    :param filter: Filter passed to the :class:`Upsampling` operators.
    :param scales: Resampling factors. Defaults to ``[2, 4, 8]``.
    """

    def __init__(self, physics, img_shape, filter="sinc", scales=None, **kwargs):
        # ``None`` sentinel instead of a mutable default list shared between instances.
        if scales is None:
            scales = [2, 4, 8]
        super().__init__(physics=physics, img_shape=img_shape, filter=filter, scales=scales, **kwargs)

    def A_adjoint(self, y, scale=None, **kwargs):
        """Apply the base adjoint, then map the result to the current scale if it is not 0."""
        self.set_scale(scale)
        y = self.base.A_adjoint(y, **kwargs)
        if self.scale == 0:
            return y
        else:
            return self.Upsamplings[self.scale - 1].A_adjoint(y)


class Pad(LinearPhysics):
    r"""
    Wraps a linear physics so that it acts on images padded at the top and left.

    :param physics: Base linear physics; its ``noise_model`` is reused.
    :param pad: Pair ``(pad_rows, pad_cols)`` of padding sizes for the last two dimensions.
    """

    def __init__(self, physics, pad):
        super().__init__(noise_model=physics.noise_model)
        self.base = physics
        self.pad = pad

    def A(self, x, **kwargs):
        """Crop the padding from ``x`` and apply the base forward operator."""
        return self.base.A(x[..., self.pad[0]:, self.pad[1]:], **kwargs)

    def A_adjoint(self, y, **kwargs):
        """Apply the base adjoint and zero-pad the result at the top and left."""
        y = self.base.A_adjoint(y, **kwargs)
        # F.pad takes (left, right, top, bottom) for the last two dimensions.
        y = torch.nn.functional.pad(y, (self.pad[1], 0, self.pad[0], 0))
        return y

    def remove_pad(self, x):
        """Crop the padded rows and columns from ``x``."""
        return x[..., self.pad[0]:, self.pad[1]:]

    def update_parameters(self, **kwargs):
        """Forward parameter updates to the base physics."""
        self.base.update_parameters(**kwargs)