# pranavSIT's picture
# added pali inference
# 74e8f2f
# Copyright 2022 Big Vision Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Preprocessing ops."""
from big_vision.pp import utils
from big_vision.pp.registry import Registry
import numpy as np
import tensorflow as tf
@Registry.register("preprocess_ops.rgb_to_grayscale_to_rgb")
@utils.InKeyOutKey(indefault="image", outdefault="image")
def get_rgb_to_grayscale_to_rgb():
  """Converts an RGB image to grayscale and back to a 3-channel image.

  Returns:
    Op applying `tf.image.rgb_to_grayscale` followed by
    `tf.image.grayscale_to_rgb`.
  """

  def _to_gray_and_back(image):
    gray = tf.image.rgb_to_grayscale(image)
    return tf.image.grayscale_to_rgb(gray)

  return _to_gray_and_back
@Registry.register("preprocess_ops.nyu_eval_crop")
def get_nyu_eval_crop():
  """Crops labels and image to valid eval area.

  The kept region is rows [45, 471) and columns [41, 601) of the frame, i.e.
  `crop_h = slice(45, 471)` and `crop_w = slice(41, 601)`.

  Returns:
    Op that asserts `data["labels"]` has shape (480, 640, 1) and
    `data["image"]` has shape (480, 640, 3), then replaces both with their
    (426, 560, C) crops.
  """
  # The row offset used to be 54, which contradicts the documented
  # slice(45, 471): 45 + 426 == 471, whereas 54 + 426 == 480. Deriving the
  # sizes from the slice bounds keeps offset and size from drifting apart.
  crop_h_start, crop_h_stop = 45, 471
  crop_w_start, crop_w_stop = 41, 601
  begin = [crop_h_start, crop_w_start, 0]
  # A size of -1 keeps every channel.
  size = [crop_h_stop - crop_h_start, crop_w_stop - crop_w_start, -1]

  def _pp(data):
    tf.debugging.assert_equal(tf.shape(data["labels"]), (480, 640, 1))
    tf.debugging.assert_equal(tf.shape(data["image"]), (480, 640, 3))
    for key in ("labels", "image"):
      data[key] = tf.slice(data[key], begin, size)
    return data

  return _pp
@Registry.register("preprocess_ops.nyu_depth")
@utils.InKeyOutKey(indefault="depth", outdefault="labels")
def get_nyu_depth():
  """Preprocesses NYU depth data.

  Returns:
    Op casting the depth map to float32 and appending a trailing axis of
    size one.
  """

  def _depth_to_labels(depth):
    depth_f32 = tf.cast(depth, tf.float32)
    return tf.expand_dims(depth_f32, -1)

  return _depth_to_labels
@Registry.register("preprocess_ops.coco_panoptic")
def get_coco_panoptic_pp():
  """COCO-panoptic: produces a mask with labels and a mask with instance ids.

  Reads `panoptic_image` and `panoptic_objects` (`id` and `label`) and writes
  `instances` and `semantics`, each with a trailing channel of size one.
  The instance channel holds values between 1 and N (position of the object in
  `panoptic_objects`, plus one); the semantic channel holds `label + 1`.
  Pixels whose id matches none of the objects are 0 in both outputs.

  Returns:
    COCO panoptic preprocessing op.
  """

  def _coco_panoptic(data):
    objects = data["panoptic_objects"]
    object_ids = tf.cast(objects["id"], tf.int32)
    object_labels = tf.cast(objects["label"], tf.int32)

    # The id is split over the 3 image channels in base 256; recombine it into
    # a single integer per pixel.
    pixel_ids = tf.einsum(
        "hwc,c->hw",
        tf.cast(data["panoptic_image"], tf.int32),
        tf.constant([1, 256, 256**2], tf.int32))

    # Broadcast into N 0/1 masks, one per object id, stacked on the last axis.
    per_object = tf.cast(
        pixel_ids[:, :, None] == object_ids[None, None, :], tf.int32)

    # Merge into a semantic and an instance id mask by a weighted sum over the
    # object axis. Pixels belonging to no mask contribute nothing and stay 0.
    # Number instances starting at 1 (0 is treated specially by
    # make_canonical).
    object_idx = tf.range(tf.shape(object_ids)[-1])
    instance_mask = tf.einsum("hwc,c->hw", per_object, object_idx + 1)
    semantic_mask = tf.einsum("hwc,c->hw", per_object, object_labels + 1)

    data["instances"] = instance_mask[:, :, None]
    data["semantics"] = semantic_mask[:, :, None]
    return data

  return _coco_panoptic
@Registry.register("preprocess_ops.make_canonical")
@utils.InKeyOutKey(indefault="labels", outdefault="labels")
def get_make_canonical(random=False, main_sort_axis="y"):
  """Renumbers instance ids by the center of mass of their masks.

  By convention, instances are in the last channel of the input.

  Args:
    random: if True, instances are numbered in a shuffled order instead of
      being sorted by their centers.
    main_sort_axis: "y" or "x". Instances are sorted by
      `center[other_axis] + extent(main_axis) * center[main_axis]`, so the
      center coordinate along this axis dominates the ordering.

  Returns:
    Op that replaces the last channel with canonical ids: 1..N for the
    distinct positive input ids, 0 where the input id is 0 and -1 where the
    input id is negative. All other channels are passed through.
  """

  def _make_canonical(image):
    """Op."""
    inst = image[..., -1]

    # One binary mask per distinct positive id. Zero and negative ids are
    # deliberately left without a mask.
    unique_ids = tf.unique(tf.reshape(inst, [-1])).y
    unique_ids = unique_ids[unique_ids > 0]
    masks = tf.cast(
        inst[None, :, :] == unique_ids[:, None, None], tf.int32)

    if random:
      masks = tf.random.shuffle(masks)
    else:
      def _center_of_mass(mask):
        # Mean (row, col) index over the non-zero pixels of the mask.
        return tf.reduce_mean(tf.cast(tf.where(mask), tf.float32), axis=0)

      centers = tf.map_fn(
          _center_of_mass, tf.cast(masks, tf.int64), dtype=tf.float32)
      centers = tf.reshape(centers, (tf.shape(centers)[0], 2))
      major = {"y": 0, "x": 1}[main_sort_axis]
      minor = 1 - major
      order = tf.argsort(
          centers[:, minor] +
          tf.cast(tf.shape(inst)[major], tf.float32) * centers[:, major])
      masks = tf.gather(masks, order)

    # Mask k (0-based) contributes k + 2, so after subtracting one the covered
    # pixels read 1..N and every uncovered pixel reads -1.
    positions = tf.range(tf.shape(unique_ids)[0])
    canonical = tf.einsum("chw,c->hw", masks, positions + 2) - 1
    # Uncovered pixels include both zero and negative ids; restore the zero id
    # from the original mask so only negative ids remain at -1.
    canonical = tf.where(inst == 0, 0, canonical)
    return tf.concat([image[..., :-1], canonical[..., None]], axis=-1)

  return _make_canonical
@Registry.register("preprocess_ops.inception_box")
def get_inception_box(
    *, area=(0.05, 1.0), aspect=(0.75, 1.33), min_obj_cover=0.0,
    outkey="box", inkey="image"):
  """Creates an inception style bounding box which can be used to crop.

  Args:
    area: passed as `area_range` to the box sampler.
    aspect: passed as `aspect_ratio_range` to the box sampler.
    min_obj_cover: passed as `min_object_covered`. When truthy, the boxes in
      `data["objects"]["bbox"]` are handed to the sampler; otherwise an empty
      box tensor is used.
    outkey: key under which the sampled box is stored.
    inkey: key of the image whose shape is sampled from.

  Returns:
    Op storing a `(begin, size)` pair in `data[outkey]`, where `begin` is
    `[y0, x0]` and `size` is `[y1 - y0, x1 - x0]` of the sampled box.
  """

  def _inception_box(data):
    image_shape = tf.shape(data[inkey])
    if min_obj_cover:
      object_boxes = data["objects"]["bbox"][None, :, :]
    else:
      object_boxes = tf.zeros([0, 0, 4])

    # Only the third output is used: the sampled box as [[[y0, x0, y1, x1]]].
    _, _, sampled = tf.image.sample_distorted_bounding_box(
        image_shape,
        bounding_boxes=object_boxes,
        min_object_covered=min_obj_cover,
        aspect_ratio_range=aspect,
        area_range=area,
        use_image_if_no_bounding_boxes=True)
    data[outkey] = (
        sampled[0, 0, :2],
        sampled[0, 0, 2:] - sampled[0, 0, :2],
    )
    return data

  return _inception_box
@Registry.register("preprocess_ops.crop_box")
@utils.InKeyOutKey(with_data=True)
def get_crop_box(*, boxkey="box"):
  """Crops an image according to bounding box in `boxkey`.

  Args:
    boxkey: key of a `(begin, size)` pair in the data dict. Both entries are
      multiplied by the image's leading (all but last) dimensions and
      truncated to int32 to obtain the crop offset and extent.

  Returns:
    Op returning the cropped image with all channels kept.
  """

  def _crop_box(image, data):
    spatial = tf.shape(image)[:-1]

    def _to_pixels(relative):
      # Scale by the spatial shape and truncate to whole pixels.
      return tf.cast(relative * tf.cast(spatial, tf.float32), tf.int32)

    rel_begin, rel_size = data[boxkey]
    # Append the channel axis: start at channel 0 and keep all (-1) channels.
    begin = tf.concat([_to_pixels(rel_begin), tf.constant((0,))], axis=0)
    size = tf.concat([_to_pixels(rel_size), tf.constant((-1,))], axis=0)
    cropped = tf.slice(image, begin, size)
    # Slicing with a dynamic size drops the static depth dimension, so it has
    # to be restored by hand from the input image.
    cropped.set_shape([None, None, image.shape[-1]])
    return cropped

  return _crop_box
@Registry.register("preprocess_ops.randu")
def get_randu(key):
  """Creates a random uniform float [0, 1) in `key`.

  Args:
    key: key of the data dict under which the scalar sample is written.

  Returns:
    Op that stores a fresh scalar sample in `data[key]` and returns `data`.
  """

  def _draw_uniform(data):
    sample = tf.random.uniform(shape=[])
    data[key] = sample
    return data

  return _draw_uniform
@Registry.register("preprocess_ops.det_fliplr")
@utils.InKeyOutKey(with_data=True)
def get_det_fliplr(*, randkey="fliplr"):
  """Flips an image horizontally based on `randkey`.

  Args:
    randkey: key of a value in the data dict; the image is flipped when that
      value is greater than 0.5.

  Returns:
    Op returning either the left-right flipped image or the original one.
  """
  # NOTE: we could unify this with regular flip when randkey=None.

  def _det_fliplr(orig_image, data):
    mirrored = tf.image.flip_left_right(orig_image)
    # 1 selects the mirrored image, 0 the original. The choice is made by
    # arithmetic blending in the image dtype rather than by branching.
    use_mirrored = tf.cast(data[randkey] > 0.5, orig_image.dtype)
    flipped_part = mirrored * use_mirrored
    original_part = orig_image * (1 - use_mirrored)
    return flipped_part + original_part

  return _det_fliplr
@Registry.register("preprocess_ops.strong_hash")
@utils.InKeyOutKey(indefault="tfds_id", outdefault="tfds_id")
def get_strong_hash():
  """Preprocessing that hashes a string.

  Returns:
    Op mapping a string tensor to bucket ids computed by
    `tf.strings.to_hash_bucket_strong` with a fixed key and a bucket count
    equal to the int64 maximum.
  """
  # Pin the bucket count to the int64 maximum explicitly. `np.iinfo(int)`
  # follows NumPy's platform default integer, which is 32-bit on some
  # platform/version combinations and would make hashes machine-dependent.
  num_buckets = np.iinfo(np.int64).max
  # Fixed two-element key so that the hash is reproducible across runs.
  hash_key = [3714561454027272724, 8800639020734831960]

  def _strong_hash(string):
    return tf.strings.to_hash_bucket_strong(string, num_buckets, hash_key)

  return _strong_hash