File size: 3,741 Bytes
12a4d59
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
import torch
from deepinv.physics.generator import PhysicsGenerator


class InpaintingMaskGenerator(PhysicsGenerator):
    r"""
    Generator of random block-occlusion masks for inpainting.

    Each generated mask starts as all ones and has ``num_blocks`` rectangular
    blocks set to zero, drawn independently for every element of the batch.
    For a mask of spatial size ``(H, W)``, each block spans
    ``int(block_size_ratio * H)`` rows and ``int(block_size_ratio * W)``
    columns, and the same zeroed region is applied to every channel.

    :param tuple mask_shape: spatial shape ``(H, W)`` of the mask; must have exactly two entries.
    :param int num_channels: number of channels of the generated mask.
    :param str device: device on which the returned mask is placed.
    :param torch.dtype dtype: dtype forwarded to :class:`PhysicsGenerator`; also used for the returned mask.
    :param float block_size_ratio: default block side length, as a fraction of the mask height/width, in ``[0, 1]``.
    :param int num_blocks: default number of blocks per mask.
    :raises ValueError: if ``mask_shape`` does not have exactly two entries.
    """

    def __init__(
        self,
        mask_shape: tuple,
        num_channels: int = 1,
        device: str = "cpu",
        dtype: type = torch.float32,
        block_size_ratio=0.1,
        num_blocks=5,
    ) -> None:
        if len(mask_shape) != 2:
            raise ValueError(
                "mask_shape must be 2D. Add channels via num_channels parameter"
            )
        # These extra kwargs are handed to PhysicsGenerator; step() later reads
        # them back as self.mask_shape / self.block_size_ratio / self.num_blocks
        # (presumably the base class stores them as attributes — confirm there).
        kwargs = {
            "mask_shape": mask_shape,
            "block_size_ratio": block_size_ratio,
            "num_blocks": num_blocks,
        }
        super().__init__(
            num_channels=num_channels,
            device=device,
            dtype=dtype,
            **kwargs,
        )

    def generate_mask(self, image_shape, block_size_ratio, num_blocks):
        r"""
        Draw a batch of random block masks.

        :param tuple image_shape: shape ``(B, C, H, W)`` of the mask to create.
        :param float block_size_ratio: block side length as a fraction of ``H`` (rows) and ``W`` (columns), in ``[0, 1]``.
        :param int num_blocks: number of blocks zeroed out in each batch element.
        :return: (torch.Tensor) CPU tensor of shape ``image_shape`` holding ones everywhere except inside the blocks, which are zero across all channels.
        :raises ValueError: if ``block_size_ratio`` is outside ``[0, 1]``.
        """
        if not 0 <= block_size_ratio <= 1:
            raise ValueError(
                f"block_size_ratio must lie in [0, 1], got {block_size_ratio}"
            )

        # Start fully observed; blocks are carved out below.
        mask = torch.ones(image_shape)
        batch_size = mask.shape[0]
        height, width = image_shape[-2], image_shape[-1]

        # Each block spans a fraction of the matching axis: its row extent
        # scales with the height and its column extent with the width.
        block_height = int(height * block_size_ratio)
        block_width = int(width * block_size_ratio)

        # Top-left corner of every block. randint's upper bound is exclusive,
        # so "+ 1" lets a block reach the last row/column and keeps the range
        # non-empty when a block spans the whole axis (block_size_ratio == 1).
        x_coords = torch.randint(
            0, width - block_width + 1, (batch_size, num_blocks)
        )
        y_coords = torch.randint(
            0, height - block_height + 1, (batch_size, num_blocks)
        )

        # Column / row indices covered by each block:
        # (B, num_blocks, block_width) and (B, num_blocks, block_height).
        x_indices = x_coords.unsqueeze(-1) + torch.arange(block_width).view(1, 1, -1)
        y_indices = y_coords.unsqueeze(-1) + torch.arange(block_height).view(1, 1, -1)

        # Broadcast both to the (B, num_blocks, block_height, block_width) pixel
        # grid of every block and flatten, so that the i-th entries of the
        # batch/row/column index vectors together name one pixel to zero.
        x_indices = x_indices.unsqueeze(2).expand(-1, -1, block_height, -1).reshape(-1)
        y_indices = y_indices.unsqueeze(3).expand(-1, -1, -1, block_width).reshape(-1)
        batch_indices = (
            torch.arange(batch_size)
            .view(-1, 1, 1)
            .expand(-1, num_blocks, block_height * block_width)
            .reshape(-1)
        )

        # The ":" on the channel axis zeroes the same pixels in every channel.
        mask[batch_indices, :, y_indices, x_indices] = 0

        return mask

    def step(
        self, batch_size: int = 1, block_size_ratio: float = None, num_blocks=None
    ):
        r"""
        Generate a batch of random block inpainting masks.

        :param int batch_size: number of masks to generate.
        :param float block_size_ratio: block side length as a fraction of the mask height/width, in ``[0, 1]``. If ``None``, the value given at construction is used.
        :param int num_blocks: number of blocks per mask. If ``None``, the value given at construction is used.
        :return: dictionary with key **'mask'**: the generated mask of shape ``(batch_size, num_channels, mask_shape[0], mask_shape[1])``, with ones on observed pixels and zeros inside the blocks.
        :raises ValueError: if the effective ``block_size_ratio`` is outside ``[0, 1]``.
        """

        # TODO: add randomness
        if block_size_ratio is None:
            block_size_ratio = self.block_size_ratio
        if num_blocks is None:
            num_blocks = self.num_blocks

        batch_shape = (
            batch_size,
            self.num_channels,
            self.mask_shape[-2],
            self.mask_shape[-1],
        )

        mask = self.generate_mask(batch_shape, block_size_ratio, num_blocks)

        # generate_mask builds the mask on CPU; move it to the configured device
        # and, when factory_kwargs carries one, cast it to the configured dtype.
        return {
            "mask": mask.to(
                device=self.factory_kwargs["device"],
                dtype=self.factory_kwargs.get("dtype", mask.dtype),
            )
        }