|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Preprocessing ops.""" |
|
from big_vision.pp import utils |
|
from big_vision.pp.registry import Registry |
|
import numpy as np |
|
import tensorflow as tf |
|
|
|
|
|
@Registry.register("preprocess_ops.rgb_to_grayscale_to_rgb")
@utils.InKeyOutKey(indefault="image", outdefault="image")
def get_rgb_to_grayscale_to_rgb():
  """Replaces an RGB image by its grayscale rendition, kept in RGB layout."""
  def _pp(image):
    # Collapse to a single luminance channel, then tile it back to 3 channels
    # so downstream ops that expect RGB keep working.
    gray = tf.image.rgb_to_grayscale(image)
    return tf.image.grayscale_to_rgb(gray)
  return _pp
|
|
|
|
|
@Registry.register("preprocess_ops.nyu_eval_crop")
def get_nyu_eval_crop(*, crop_h_start=54, crop_h_size=426,
                      crop_w_start=41, crop_w_size=560):
  """Crops labels and image to valid eval area.

  The defaults reproduce the original hard-coded crop window
  (rows 54:480, cols 41:601); they are now keyword-only parameters so the
  window can be adjusted from configs without changing callers.

  Args:
    crop_h_start: First row of the crop window.
    crop_h_size: Number of rows to keep.
    crop_w_start: First column of the crop window.
    crop_w_size: Number of columns to keep.

  Returns:
    A pp op that crops `data["labels"]` and `data["image"]` in place.
  """
  def _pp(data):
    # Fail fast on unexpected input sizes: this op is meant for fixed
    # 480x640 inputs (assert shapes match exactly).
    tf.debugging.assert_equal(tf.shape(data["labels"]), (480, 640, 1))
    tf.debugging.assert_equal(tf.shape(data["image"]), (480, 640, 3))
    # Same spatial offset for both tensors; -1 keeps all channels.
    begin = [crop_h_start, crop_w_start, 0]
    size = [crop_h_size, crop_w_size, -1]
    data["labels"] = tf.slice(data["labels"], begin, size)
    data["image"] = tf.slice(data["image"], begin, size)
    return data
  return _pp
|
|
|
|
|
@Registry.register("preprocess_ops.nyu_depth")
@utils.InKeyOutKey(indefault="depth", outdefault="labels")
def get_nyu_depth():
  """Preprocesses NYU depth data."""
  def _pp(depth):
    # Cast to float32 and append a trailing channel axis: (h, w) -> (h, w, 1).
    depth_f32 = tf.cast(depth, tf.float32)
    return depth_f32[..., None]
  return _pp
|
|
|
|
|
@Registry.register("preprocess_ops.coco_panoptic")
def get_coco_panoptic_pp():
  """COCO-panoptic: produces a mask with labels and a mask with instance ids.

  Instance channel will have values between 1 and N, and -1 for non-annotated
  pixels.

  NOTE(review): as written, a pixel whose decoded id matches none of the
  annotated instance ids sums to 0 (not -1) in both output maps — confirm
  whether the "-1" above is stale or whether inputs guarantee full coverage.

  Returns:
    COCO panoptic preprocessing op.
  """
  def _coco_panoptic(data):
    instance_ids = tf.cast(data["panoptic_objects"]["id"], tf.int32)
    instance_labels = tf.cast(data["panoptic_objects"]["label"], tf.int32)

    # Decode the per-pixel instance id from the 3-channel panoptic image:
    # id = R + 256*G + 256**2*B, yielding an (h, w) integer map.
    id_mask = tf.einsum(
        "hwc,c->hw",
        tf.cast(data["panoptic_image"], tf.int32),
        tf.constant([1, 256, 256**2], tf.int32))

    # Binary membership masks, shape (h, w, n): 1 where the pixel's decoded
    # id equals the annotated instance id, 0 elsewhere.
    n_masks = tf.cast(
        id_mask[:, :, None] == instance_ids[None, None, :], tf.int32)

    # Collapse the per-instance masks into dense (h, w) maps. The +1 shifts
    # make annotated values start at 1 (instance index / class label + 1);
    # each pixel belongs to at most one mask, so the einsum-sum picks out
    # exactly that instance's value.
    instance_idx = tf.range(tf.shape(instance_ids)[-1])
    instances = tf.einsum("hwc,c->hw", n_masks, instance_idx + 1)
    semantics = tf.einsum("hwc,c->hw", n_masks, instance_labels + 1)

    # Re-add a trailing channel axis to match the usual (h, w, 1) mask layout.
    data["instances"] = instances[:, :, None]
    data["semantics"] = semantics[:, :, None]
    return data

  return _coco_panoptic
|
|
|
|
|
@Registry.register("preprocess_ops.make_canonical")
@utils.InKeyOutKey(indefault="labels", outdefault="labels")
def get_make_canonical(random=False, main_sort_axis="y"):
  """Makes id mask ordered from left to right based on the center of mass.

  Args:
    random: If True, instance ids are assigned in random order instead of
      the center-of-mass ordering.
    main_sort_axis: "y" or "x"; the primary axis for the ordering.

  Returns:
    A pp op that relabels the last channel of the mask with canonical ids.
  """

  def _make_canonical(image):
    """Op."""
    # The instance-id map is the last channel; earlier channels pass through.
    instimg = image[..., -1]

    # Unique positive instance ids present in the mask. Non-positive values
    # (e.g. 0 background) are excluded and handled separately below.
    ids = tf.unique(tf.reshape(instimg, [-1])).y
    ids = ids[ids > 0]
    # Binary membership masks, shape (n_ids, h, w).
    n_masks = tf.cast(
        instimg[None, :, :] == ids[:, None, None], tf.int32)

    if not random:
      # Center of mass (mean (row, col) of nonzero pixels) per instance mask.
      f = lambda x: tf.reduce_mean(tf.cast(tf.where(x), tf.float32), axis=0)
      centers = tf.map_fn(f, tf.cast(n_masks, tf.int64), dtype=tf.float32)
      centers = tf.reshape(centers, (tf.shape(centers)[0], 2))
      major = {"y": 0, "x": 1}[main_sort_axis]
      # Lexicographic sort key: major-axis center scaled by the image extent
      # dominates, minor-axis center breaks ties.
      perm = tf.argsort(
          centers[:, 1 - major] +
          tf.cast(tf.shape(instimg)[major], tf.float32) * centers[:, major])
      n_masks = tf.gather(n_masks, perm)
    else:
      n_masks = tf.random.shuffle(n_masks)

    # Assign canonical ids in the chosen order: mask i gets (i + 2) - 1,
    # i.e. ids 1..n; pixels covered by no mask end up at 0 - 1 = -1.
    idx = tf.range(tf.shape(ids)[0])
    can_mask = tf.einsum("chw,c->hw", n_masks, idx + 2) - 1

    # Background (original id 0) is restored to 0; other non-positive
    # original ids remain -1 from the step above.
    can_mask = tf.where(instimg == 0, 0, can_mask)
    return tf.concat([image[..., :-1], can_mask[..., None]], axis=-1)

  return _make_canonical
|
|
|
|
|
@Registry.register("preprocess_ops.inception_box")
def get_inception_box(
    *, area=(0.05, 1.0), aspect=(0.75, 1.33), min_obj_cover=0.0,
    outkey="box", inkey="image"):
  """Creates an inception style bounding box which can be used to crop."""
  def _inception_box(data):
    # Object boxes only matter when a minimum coverage is requested;
    # otherwise pass an empty box set and let the sampler use the full image.
    if min_obj_cover:
      bboxes = data["objects"]["bbox"][None, :, :]
    else:
      bboxes = tf.zeros([0, 0, 4])
    _, _, box = tf.image.sample_distorted_bounding_box(
        tf.shape(data[inkey]),
        area_range=area,
        aspect_ratio_range=aspect,
        min_object_covered=min_obj_cover,
        bounding_boxes=bboxes,
        use_image_if_no_bounding_boxes=True)
    # Store as (origin, size) in relative coordinates rather than corners.
    origin = box[0, 0, :2]
    size = box[0, 0, 2:] - box[0, 0, :2]
    data[outkey] = (origin, size)
    return data
  return _inception_box
|
|
|
|
|
@Registry.register("preprocess_ops.crop_box")
@utils.InKeyOutKey(with_data=True)
def get_crop_box(*, boxkey="box"):
  """Crops an image according to bounding box in `boxkey`."""
  def _crop_box(image, data):
    rel_begin, rel_size = data[boxkey]
    # Convert the relative (origin, size) box to pixel coordinates.
    spatial = tf.cast(tf.shape(image)[:-1], tf.float32)
    begin_px = tf.cast(rel_begin * spatial, tf.int32)
    size_px = tf.cast(rel_size * spatial, tf.int32)
    # Extend with the channel axis: start at channel 0, keep all (-1).
    begin_full = tf.concat([begin_px, tf.constant((0,))], axis=0)
    size_full = tf.concat([size_px, tf.constant((-1,))], axis=0)
    cropped = tf.slice(image, begin_full, size_full)

    # Spatial dims are dynamic, but the channel count is statically known.
    cropped.set_shape([None, None, image.shape[-1]])
    return cropped
  return _crop_box
|
|
|
|
|
@Registry.register("preprocess_ops.randu")
def get_randu(key):
  """Creates a random uniform float [0, 1) in `key`."""
  def _sample_uniform(data):
    # Scalar draw in [0, 1), stored under `key` for downstream ops to read.
    data[key] = tf.random.uniform(shape=[])
    return data
  return _sample_uniform
|
|
|
|
|
@Registry.register("preprocess_ops.det_fliplr")
@utils.InKeyOutKey(with_data=True)
def get_det_fliplr(*, randkey="fliplr"):
  """Flips an image horizontally based on `randkey`."""

  def _det_fliplr(orig_image, data):
    mirrored = tf.image.flip_left_right(orig_image)
    # 1 when the pre-sampled value exceeds 0.5, else 0 (in the image dtype).
    use_flip = tf.cast(data[randkey] > 0.5, orig_image.dtype)
    # Arithmetic blend: since use_flip is exactly 0 or 1, this selects one
    # of the two images without a branch.
    return use_flip * mirrored + (1 - use_flip) * orig_image
  return _det_fliplr
|
|
|
|
|
@Registry.register("preprocess_ops.strong_hash")
@utils.InKeyOutKey(indefault="tfds_id", outdefault="tfds_id")
def get_strong_hash():
  """Preprocessing that hashes a string."""
  def _strong_hash(string):
    # Maximum bucket count representable by the platform int.
    num_buckets = np.iinfo(int).max
    # Fixed key makes the hash deterministic across processes and runs.
    return tf.strings.to_hash_bucket_strong(
        string, num_buckets,
        key=[3714561454027272724, 8800639020734831960])
  return _strong_hash
|
|