Spaces:
Sleeping
Sleeping
import torch | |
from deepinv.physics.generator import PhysicsGenerator | |
class InpaintingMaskGenerator(PhysicsGenerator):
    r"""
    Generator of random block-occlusion masks for inpainting.

    Each generated mask starts as all ones. ``num_blocks`` axis-aligned
    rectangular blocks per batch element are then set to zero. Each block side
    is ``block_size_ratio`` times the matching image side (truncated to an
    integer). Block positions are drawn independently, so blocks may overlap.
    The same zero pattern is applied to every channel of a batch element.

    :param tuple mask_shape: spatial shape ``(H, W)`` of the mask; must have length 2.
    :param int num_channels: number of channels of the generated mask.
    :param str device: device the mask returned by :meth:`step` is moved to.
    :param torch.dtype dtype: dtype forwarded to :class:`PhysicsGenerator`;
        :meth:`step` casts the returned mask to it.
    :param float block_size_ratio: default fraction, in ``[0, 1]``, of each
        image side covered by one block.
    :param int num_blocks: default number of blocks per batch element.
    :raises ValueError: if ``mask_shape`` does not have exactly two entries.
    """

    def __init__(
        self,
        mask_shape: tuple,
        num_channels: int = 1,
        device: str = "cpu",
        dtype: torch.dtype = torch.float32,
        block_size_ratio=0.1,
        num_blocks=5,
    ) -> None:
        if len(mask_shape) != 2:
            raise ValueError(
                "mask_shape must be 2D. Add channels via num_channels parameter"
            )
        # These are handed to PhysicsGenerator and read back in ``step`` as
        # self.mask_shape, self.block_size_ratio and self.num_blocks.
        kwargs = {
            "mask_shape": mask_shape,
            "block_size_ratio": block_size_ratio,
            "num_blocks": num_blocks,
        }
        super().__init__(
            num_channels=num_channels,
            device=device,
            dtype=dtype,
            **kwargs,
        )

    def generate_mask(self, image_shape, block_size_ratio, num_blocks):
        r"""
        Build a batch of masks with randomly placed zero blocks.

        :param tuple image_shape: mask shape, indexed as ``(B, C, H, W)``.
        :param float block_size_ratio: fraction in ``[0, 1]`` of each side
            covered by a block.
        :param int num_blocks: number of blocks per batch element.
        :return: tensor of shape ``image_shape`` (torch default device/dtype)
            with zeros inside the blocks and ones elsewhere.
        :raises ValueError: if ``block_size_ratio`` is not in ``[0, 1]``.
        """
        # Written so that NaN is rejected as well.
        if not 0.0 <= block_size_ratio <= 1.0:
            raise ValueError(
                f"block_size_ratio must be in [0, 1], got {block_size_ratio}"
            )

        mask = torch.ones(image_shape)
        batch_size = image_shape[0]
        height, width = image_shape[-2], image_shape[-1]

        # Each block side scales with the matching image side: width with W
        # (last dim), height with H (second-to-last dim).
        block_height = int(height * block_size_ratio)
        block_width = int(width * block_size_ratio)

        # Top-left corners. torch.randint's upper bound is exclusive; the +1
        # lets a block reach the bottom/right edge and keeps the range
        # non-empty when a block spans the whole side (ratio == 1).
        x_coords = torch.randint(
            0, width - block_width + 1, (batch_size, num_blocks)
        )
        y_coords = torch.randint(
            0, height - block_height + 1, (batch_size, num_blocks)
        )

        # Column / row indices covered by each block:
        # (B, N, block_width) and (B, N, block_height).
        x_indices = x_coords.unsqueeze(-1) + torch.arange(block_width)
        y_indices = y_coords.unsqueeze(-1) + torch.arange(block_height)

        # Broadcast everything to (B, N, block_height, block_width) and flatten.
        # The batch, row and column index tensors then line up element-wise
        # for advanced indexing.
        grid_shape = (batch_size, num_blocks, block_height, block_width)
        cols = x_indices.unsqueeze(2).expand(grid_shape).reshape(-1)
        rows = y_indices.unsqueeze(3).expand(grid_shape).reshape(-1)
        batch_indices = (
            torch.arange(batch_size).view(-1, 1, 1, 1).expand(grid_shape).reshape(-1)
        )

        # The ``:`` on the channel axis zeroes the same blocks in every channel.
        mask[batch_indices, :, rows, cols] = 0
        return mask

    def step(
        self, batch_size: int = 1, block_size_ratio: float = None, num_blocks=None
    ):
        r"""
        Generate a batch of random block inpainting masks.

        :param int batch_size: number of masks to generate.
        :param float block_size_ratio: fraction of each image side covered by
            a block; defaults to the value given at construction.
        :param int num_blocks: number of blocks per mask; defaults to the value
            given at construction.
        :return: dictionary with key **'mask'**: tensor of shape
            ``(batch_size, num_channels, mask_shape[-2], mask_shape[-1])`` with
            zeros inside the blocks and ones elsewhere.
        """
        # TODO: add randomness
        if block_size_ratio is None:
            block_size_ratio = self.block_size_ratio
        if num_blocks is None:
            num_blocks = self.num_blocks

        batch_shape = (
            batch_size,
            self.num_channels,
            self.mask_shape[-2],
            self.mask_shape[-1],
        )
        mask = self.generate_mask(batch_shape, block_size_ratio, num_blocks)

        # factory_kwargs is populated by PhysicsGenerator. Fall back to the
        # mask's own dtype if it carries no "dtype" entry.
        target_dtype = self.factory_kwargs.get("dtype", mask.dtype)
        return {
            "mask": mask.to(device=self.factory_kwargs["device"], dtype=target_dtype)
        }