pranavSIT's picture
added pali inference
74e8f2f
# Copyright 2024 Big Vision Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Preprocessing for videos."""
from big_vision.pp import utils
from big_vision.pp.registry import Registry
import tensorflow as tf
@Registry.register('preprocess_ops.video_decode')
def video_decode(res, *, inkey='episodic_images', outkey='image', channels=0):
  """Decodes a sequence of JPEG frames, resizes them and rescales to [-1, 1].

  Args:
    res: target height and width; every frame is resized to `(res, res)`.
    inkey: key of the encoded frames in the example dict. Defaults to
      'episodic_images', the key this op has always read from.
    outkey: key under which the decoded frames are written. Defaults to
      'image', the key this op has always written to.
    channels: forwarded to `tf.io.decode_jpeg`. The default 0 keeps each
      JPEG's own channel count (unchanged behavior); pass 3 to force RGB so
      that videos mixing grayscale and color frames can still be stacked.

  Returns:
    A preprocessing function that sets `data[outkey]` to a float32 tensor of
    the decoded, resized frames (stacked along axis 0) with pixel values
    mapped from [0, 255] to [-1, 1], and returns the (mutated) dict.
  """

  def _decode_and_resize(encoded_frame):
    # Decode a single JPEG frame, then resize it; tf.image.resize yields
    # float32, which matches the map_fn output signature below.
    frame = tf.io.decode_jpeg(encoded_frame, channels=channels)
    return tf.image.resize(frame, (res, res))

  def _pp(data):
    # Process every frame independently; map_fn stacks results on axis 0.
    images = tf.map_fn(
        _decode_and_resize, data[inkey], fn_output_signature=tf.float32)
    # Rescale pixel values from [0, 255] to [-1, 1].
    data[outkey] = 2 * (images / 255.) - 1.0
    return data

  return _pp
@Registry.register('preprocess_ops.video_ensure_shape')
def video_ensure_shape(key, shape):
  """Builds an op that pins the shape of one tensor in the example.

  Args:
    key: dict key of the tensor to check.
    shape: expected shape, passed straight to `tf.ensure_shape`.

  Returns:
    A function that replaces `data[key]` with the result of
    `tf.ensure_shape(data[key], shape)` and returns the (mutated) dict.
  """

  def _ensure(data):
    checked = tf.ensure_shape(data[key], shape)
    data[key] = checked
    return data

  return _ensure
@Registry.register('preprocess_ops.video_replicate_img')
def video_replicate_img(replicas, num_frames):
  """Lengthens short videos by looping their frames, then truncates.

  `data['image']` is repeated `replicas` times along axis 0 with `tf.tile`
  (so frames appear in order f0..f{T-1}, f0..f{T-1}, ...) and only the first
  `num_frames` frames are kept.

  Args:
    replicas: number of times the frame sequence is repeated before
      truncation. The output only reaches `num_frames` frames when
      `replicas * T >= num_frames`, T being the input frame count; otherwise
      fewer frames are returned.
    num_frames: number of leading frames kept after tiling.

  Returns:
    A preprocessing function that updates `data['image']` in place and
    returns the dict. `data['image']` must be rank 4 because the tile
    multiples have four entries (presumably (T, H, W, C) — confirm with
    the producing op).
  """
  multiples = [replicas, 1, 1, 1]

  def _replicate_img(data):
    looped = tf.tile(data['image'], multiples)
    data['image'] = looped[:num_frames]
    return data

  return _replicate_img
@Registry.register('preprocess_ops.video_choice')
@utils.InKeyOutKey()
def video_choice(empty_fallback=None):
  """Builds an op that picks one element of a tensor uniformly at random.

  The input is flattened to 1D first. If it has no elements, the result is
  `empty_fallback`, or a zero of the input's dtype when `empty_fallback` is
  None. `utils.InKeyOutKey` presumably selects which dict entry is read and
  written — see that decorator for details.

  Args:
    empty_fallback: value returned for empty inputs; None means zero.

  Returns:
    A function mapping a tensor to a scalar tensor of the same dtype.
  """

  def _choice(x):
    flat = tf.reshape(x, (-1,))
    # A sentinel is appended so that indexing stays valid for empty inputs;
    # it is only ever selected when `flat` has no elements.
    if empty_fallback is None:
      sentinel = tf.zeros(1, flat.dtype)
    else:
      sentinel = [empty_fallback]
    padded = tf.concat([flat, sentinel], axis=0)
    upper = tf.maximum(tf.shape(flat)[0], 1)
    idx = tf.random.uniform([], 0, upper, dtype=tf.int32)
    return padded[idx]

  return _choice
@Registry.register('preprocess_ops.stack_images')
def stack_images(inkeys=(), outkey='image'):
  """Stacks tensors from several keys into one tensor along a new axis 0.

  Args:
    inkeys: non-empty sequence of keys whose tensors are stacked, in order.
      `tf.stack` requires all of them to share shape and dtype.
    outkey: key under which the stacked tensor is stored.

  Returns:
    A preprocessing function that sets `data[outkey]` and returns the
    (mutated) dict.

  Raises:
    ValueError: if `inkeys` is empty. Without this check `tf.stack([])` does
      not fail clearly (it converts the empty list to an empty tensor), so a
      misconfigured op would silently emit a meaningless output.
  """
  # Materialize once so one-shot iterables are not exhausted after the
  # first example; iteration order is preserved.
  inkeys = tuple(inkeys)
  if not inkeys:
    raise ValueError('stack_images requires at least one key in `inkeys`.')

  def _pp(data):
    data[outkey] = tf.stack([data[inkey] for inkey in inkeys])
    return data

  return _pp