|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Preprocessing for videos.""" |
|
|
|
from big_vision.pp import utils |
|
from big_vision.pp.registry import Registry |
|
|
|
import tensorflow as tf |
|
|
|
|
|
@Registry.register('preprocess_ops.video_decode')
def video_decode(res):
  """Builds an op that decodes JPEG frames, resizes them and rescales values.

  Args:
    res: Side length handed to `tf.image.resize`; each decoded frame is
      resized to `(res, res)`.

  Returns:
    A function that takes a `data` dict, reads `data['episodic_images']`,
    writes the processed frames to `data['image']` and returns the same dict.
  """
  target_size = (res, res)

  def _decode_and_resize(encoded_frame):
    # One element of `episodic_images`: decode it, then resize it.
    decoded = tf.io.decode_jpeg(encoded_frame)
    return tf.image.resize(decoded, target_size)

  def _pp(data):
    frames = tf.map_fn(
        _decode_and_resize,
        data['episodic_images'],
        fn_output_signature=tf.float32,
    )
    # Affine rescaling x -> 2 * (x / 255) - 1, so 0 maps to -1 and 255 to 1.
    data['image'] = 2 * (frames / 255.) - 1.0
    return data

  return _pp
|
|
|
|
|
@Registry.register('preprocess_ops.video_ensure_shape')
def video_ensure_shape(key, shape):
  """Builds an op that applies `tf.ensure_shape` to one entry of the data dict.

  Args:
    key: Name of the entry in the data dict to check.
    shape: Shape forwarded unchanged to `tf.ensure_shape`.

  Returns:
    A function that replaces `data[key]` with
    `tf.ensure_shape(data[key], shape)` and returns the same dict.
  """

  def _video_ensure_shape(data):
    checked = tf.ensure_shape(data[key], shape)
    data[key] = checked
    return data

  return _video_ensure_shape
|
|
|
|
|
@Registry.register('preprocess_ops.video_replicate_img')
def video_replicate_img(replicas, num_frames):
  """Builds an op that tiles a video's frames and keeps a fixed-length prefix.

  The tensor stored under `data['image']` is repeated `replicas` times along
  its first axis, and only the first `num_frames` entries of the tiled result
  are kept. This is meant to bring short videos up to the expected number of
  frames (replicate, then select).

  Args:
    replicas: Number of repetitions along the first axis before selection.
      Not validated here; the op does not check it against `num_frames`.
    num_frames: Number of leading entries kept after tiling. If the tiled
      tensor is shorter than this, the slice returns all of it.

  Returns:
    A function that rewrites `data['image']` as described above and returns
    the same dict.
  """
  # Four multiples, so the input must be rank 4; only the first axis is tiled.
  multiples = [replicas, 1, 1, 1]

  def _replicate_img(data):
    tiled = tf.tile(data['image'], multiples)
    data['image'] = tiled[:num_frames]
    return data

  return _replicate_img
|
|
|
|
|
@Registry.register('preprocess_ops.video_choice')
@utils.InKeyOutKey()
def video_choice(empty_fallback=None):
  """Builds an op that draws one random element from a flattened tensor.

  Args:
    empty_fallback: Value produced when the input has no elements. If `None`,
      a zero of the input's dtype is used instead.

  Returns:
    A function (before the `utils.InKeyOutKey()` wrapping) that maps a tensor
    to a single element of it, or to the fallback when the tensor is empty.
  """

  def _choice(x):
    flat = tf.reshape(x, (-1,))

    # Append a sentinel entry so that indexing stays valid for empty input.
    if empty_fallback is None:
      sentinel = tf.zeros(1, flat.dtype)
    else:
      sentinel = [empty_fallback]
    candidates = tf.concat([flat, sentinel], axis=0)

    # The draw excludes the sentinel unless it is the only entry available.
    upper = tf.maximum(tf.shape(candidates)[0] - 1, 1)
    index = tf.random.uniform([], 0, upper, dtype=tf.int32)
    return candidates[index]

  return _choice
|
|
|
|
|
@Registry.register('preprocess_ops.stack_images')
def stack_images(inkeys=(), outkey='image'):
  """Builds an op that stacks several data-dict entries into one tensor.

  Args:
    inkeys: Keys of the entries to stack, in the order they are stacked.
    outkey: Key under which the result of `tf.stack` is stored.

  Returns:
    A function that writes the stacked tensor to `data[outkey]` and returns
    the same dict.
  """

  def _pp(data):
    selected = [data[name] for name in inkeys]
    data[outkey] = tf.stack(selected)
    return data

  return _pp
|
|