import torch import torchvision from deepinv.physics.blur import rotate from deepinv.physics.generator import PSFGenerator def gaussian_blur_padded(sigma=(1, 1), angle=0, filt_size=None): r""" Padded gaussian blur filter. Defined as .. math:: \begin{equation*} G(x, y) = \frac{1}{2\pi\sigma_x\sigma_y} \exp{\left(-\frac{x'^2}{2\sigma_x^2} - \frac{y'^2}{2\sigma_y^2}\right)} \end{equation*} where :math:`x'` and :math:`y'` are the rotated coordinates obtained by rotating $(x, y)$ around the origin by an angle :math:`\theta`: .. math:: \begin{align*} x' &= x \cos(\theta) - y \sin(\theta) \\ y' &= x \sin(\theta) + y \cos(\theta) \end{align*} with :math:`\sigma_x` and :math:`\sigma_y` the standard deviations along the :math:`x'` and :math:`y'` axes. :param float, tuple[float] sigma: standard deviation of the gaussian filter. If sigma is a float the filter is isotropic, whereas if sigma is a tuple of floats (sigma_x, sigma_y) the filter is anisotropic. :param float angle: rotation angle of the filter in degrees (only useful for anisotropic filters) """ if isinstance(sigma, (int, float)): sigma = (sigma, sigma) device = "cpu" elif isinstance(sigma, torch.Tensor): device = sigma.device s = max(sigma) c = int(s / 0.3 + 1) k_size = 2 * c + 1 delta = torch.arange(k_size).to(device) x, y = torch.meshgrid(delta, delta, indexing="ij") x = x - c y = y - c filt = (x / sigma[0]).pow(2) filt += (y / sigma[1]).pow(2) filt = torch.exp(-filt / 2.0) filt = ( rotate( filt.unsqueeze(0).unsqueeze(0), angle, interpolation=torchvision.transforms.InterpolationMode.BILINEAR, ) .squeeze(0) .squeeze(0) ) filt = filt / filt.flatten().sum() filt = filt.unsqueeze(0).unsqueeze(0) if filt_size is not None: filt = torch.nn.functional.pad( filt, ( (filt_size[0] - filt.shape[-2]) // 2, (filt_size[0] - filt.shape[-2] + 1) // 2, (filt_size[1] - filt.shape[-1]) // 2, (filt_size[1] - filt.shape[-1] + 1) // 2, ), ) return filt class GaussianBlurGenerator(PSFGenerator): def __init__( self, psf_size: tuple, num_channels: int = 1, device: str = "cpu", dtype: type = torch.float32, l: float = 0.3, sigma: float = 0.25, sigma_min: float = 0.01, sigma_max: float = 4.0, ) -> None: kwargs = { "l": l, "sigma": sigma, "sigma_min": sigma_min, "sigma_max": sigma_max, } if len(psf_size) != 2: raise ValueError( "psf_size must 2D. Add channels via num_channels parameter" ) super().__init__( psf_size=psf_size, num_channels=num_channels, device=device, dtype=dtype, **kwargs, ) def step(self, batch_size: int = 1, sigma: float = None, **kwargs): r""" Generate a random motion blur PSF with parameters :math:`\sigma` and :math:`l` :param int batch_size: batch_size. :param float sigma: the standard deviation of the Gaussian Process :param float l: the length scale of the trajectory :return: dictionary with key **'filter'**: the generated PSF of shape `(batch_size, 1, psf_size[0], psf_size[1])` """ sigmas = [ self.sigma_min + torch.rand(2, **self.factory_kwargs) * (self.sigma_max - self.sigma_min) for batch in range(batch_size) ] angles = [ (torch.rand(1, **self.factory_kwargs) * 180.0).item() for batch in range(batch_size) ] kernels = [ gaussian_blur_padded(sigma, angle, filt_size=self.psf_size) for sigma, angle in zip(sigmas, angles) ] kernel = torch.cat(kernels, dim=0) return { "filter": kernel.expand( -1, self.num_channels, -1, -1, ) }