|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Operations for image patches.""" |
|
|
|
from __future__ import absolute_import |
|
from __future__ import division |
|
from __future__ import print_function |
|
|
|
import tensorflow.compat.v1 as tf |
|
|
|
|
|
def get_patch_mask(y, x, patch_size, image_shape):
  """Creates a 2D mask array for a square patch of a given size and location.

  The mask is created with its center at the y and x coordinates, which must be
  within the image. While the mask center must be within the image, the mask
  itself can be partially outside of it; the part outside is clipped away. If
  patch_size is an even number, then the mask extends one more pixel towards
  the lower-valued coordinates (top and left) than towards the higher-valued
  ones. For example, a patch of size 4 centered at c covers [c - 2, c + 2),
  while a patch of size 3 covers [c - 1, c + 2).

  Args:
    y: An integer or scalar integer tensor (int32 or int64). The vertical
      coordinate of the patch mask center. Must be within the range
      [0, image_height).
    x: An integer or scalar integer tensor (int32 or int64). The horizontal
      coordinate of the patch mask center. Must be within the range
      [0, image_width).
    patch_size: An integer or scalar integer tensor. The square size of the
      patch mask. Must be at least 1.
    image_shape: A list or 1D integer tensor representing the shape of the
      image to which the mask will correspond, with the first two values being
      image height and width. For example, [image_height, image_width] or
      [image_height, image_width, image_channels].

  Returns:
    Boolean mask tensor of shape [image_height, image_width] with True values
    for the patch.

  Raises:
    tf.errors.InvalidArgumentError: if x is not in the range [0, image_width), y
      is not in the range [0, image_height), or patch_size is not at least 1.
  """
  # Normalize every integer input to int32 so that Python ints, int32 tensors
  # and int64 tensors can be mixed freely in the arithmetic below.
  image_hw = tf.cast(image_shape[:2], dtype=tf.int32)
  patch_size = tf.cast(patch_size, dtype=tf.int32)
  mask_center_yx = tf.stack(
      [tf.cast(y, dtype=tf.int32), tf.cast(x, dtype=tf.int32)])

  # Routing the center through tf.identity under these control dependencies
  # makes every downstream op depend on the checks (needed in graph mode).
  with tf.control_dependencies([
      tf.debugging.assert_greater_equal(
          patch_size, 1,
          message='Patch size must be >= 1'),
      tf.debugging.assert_greater_equal(
          mask_center_yx, 0,
          message='Patch center (y, x) must be >= (0, 0)'),
      tf.debugging.assert_less(
          mask_center_yx, image_hw,
          message='Patch center (y, x) must be < image (h, w)')
  ]):
    mask_center_yx = tf.identity(mask_center_yx)

  # Split the patch around its center with exact integer arithmetic:
  # floor(p / 2) pixels before the center and ceil(p / 2) from the center on.
  # (A float32 round trip loses precision for large patch sizes.)
  size_before = patch_size // 2
  size_after = patch_size - size_before

  # Clip the half-open [start, end) patch extent to the image bounds.
  start_yx = tf.maximum(mask_center_yx - size_before, 0)
  end_yx = tf.minimum(mask_center_yx + size_after, image_hw)

  start_y, start_x = start_yx[0], start_yx[1]
  end_y, end_x = end_yx[0], end_yx[1]

  # Build a block of True values for the visible part of the patch, then pad
  # it with False up to the full image size.
  mask = tf.ones([end_y - start_y, end_x - start_x], dtype=tf.bool)
  paddings = [[start_y, image_hw[0] - end_y],
              [start_x, image_hw[1] - end_x]]
  return tf.pad(mask, paddings)
|
|