# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Video classification input and model functions for serving/inference.""" | |
from typing import Mapping, Dict, Text | |
import tensorflow as tf, tf_keras | |
from official.vision.dataloaders import video_input | |
from official.vision.serving import export_base | |
from official.vision.tasks import video_classification | |


class VideoClassificationModule(export_base.ExportModule):
  """Video classification Module."""

  def _build_model(self):
    input_params = self.params.task.train_data
    self._num_frames = input_params.feature_shape[0]
    self._stride = input_params.temporal_stride
    self._min_resize = input_params.min_image_size
    self._crop_size = input_params.feature_shape[1]
    self._output_audio = input_params.output_audio
    task = video_classification.VideoClassificationTask(self.params.task)
    return task.build_model()

  def _decode_tf_example(self, encoded_inputs: tf.Tensor):
    sequence_description = {
        # Each entry is a JPEG-encoded image stored as a string.
        video_input.IMAGE_KEY:
            tf.io.FixedLenSequenceFeature((), tf.string),
    }
    if self._output_audio:
      sequence_description[self._params.task.validation_data.audio_feature] = (
          tf.io.VarLenFeature(dtype=tf.float32))
    _, decoded_tensors = tf.io.parse_single_sequence_example(
        encoded_inputs, {}, sequence_description)
    for key, value in decoded_tensors.items():
      if isinstance(value, tf.SparseTensor):
        decoded_tensors[key] = tf.sparse.to_dense(value)
    return decoded_tensors
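
  # For reference, a minimal sketch (not part of the original file) of how a
  # serialized input for `_decode_tf_example` could be built with the
  # tf.train.SequenceExample proto API; `encoded_frames` is a hypothetical
  # list of JPEG byte strings:
  #
  #   example = tf.train.SequenceExample()
  #   frames = example.feature_lists.feature_list[video_input.IMAGE_KEY]
  #   for jpeg_bytes in encoded_frames:
  #     frames.feature.add().bytes_list.value.append(jpeg_bytes)
  #   serialized = example.SerializeToString()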

  def _preprocess_image(self, image):
    image = video_input.process_image(
        image=image,
        is_training=False,
        num_frames=self._num_frames,
        stride=self._stride,
        num_test_clips=1,
        min_resize=self._min_resize,
        crop_size=self._crop_size,
        num_crops=1)
    image = tf.cast(image, tf.float32)  # Ideally use the dtype from config.
    features = {'image': image}
    return features

  def _preprocess_audio(self, audio):
    features = {}
    audio = tf.cast(audio, dtype=tf.float32)  # Ideally use dtype from config.
    # Deterministically sample a fixed-length (20-step) subsequence.
    audio = video_input.preprocess_ops_3d.sample_sequence(
        audio, 20, random=False, stride=1)
    audio = tf.ensure_shape(
        audio, self._params.task.validation_data.audio_feature_shape)
    features['audio'] = audio
    return features

  def inference_from_tf_example(
      self, encoded_inputs: tf.Tensor) -> Mapping[str, tf.Tensor]:
    with tf.device('cpu:0'):
      if self._output_audio:
        audio_feature = self._params.task.validation_data.audio_feature
        inputs = tf.map_fn(
            self._decode_tf_example, (encoded_inputs),
            fn_output_signature={
                video_input.IMAGE_KEY: tf.string,
                audio_feature: tf.float32
            })
        return self.serve(inputs[video_input.IMAGE_KEY],
                          inputs[audio_feature])
      else:
        inputs = tf.map_fn(
            self._decode_tf_example, (encoded_inputs),
            fn_output_signature={
                video_input.IMAGE_KEY: tf.string,
            })
        return self.serve(inputs[video_input.IMAGE_KEY], tf.zeros([1, 1]))

  def inference_from_image_tensors(
      self, input_frames: tf.Tensor) -> Mapping[str, tf.Tensor]:
    return self.serve(input_frames, tf.zeros([1, 1]))

  def inference_from_image_audio_tensors(
      self, input_frames: tf.Tensor,
      input_audio: tf.Tensor) -> Mapping[str, tf.Tensor]:
    return self.serve(input_frames, input_audio)

  def inference_from_image_bytes(self, inputs: tf.Tensor):
    raise NotImplementedError(
        'Video classification does not support image bytes input.')

  def serve(self, input_frames: tf.Tensor, input_audio: tf.Tensor):
    """Casts the inputs to float, preprocesses them and runs inference.

    Args:
      input_frames: uint8 Tensor of shape [batch_size, num_frames, height,
        width, 3], or a string Tensor of JPEG-encoded frames.
      input_audio: float32 Tensor of audio features; ignored when the config
        does not set `output_audio`.

    Returns:
      A dictionary holding the classification logits and the output
      probabilities.
    """
    with tf.device('cpu:0'):
      inputs = tf.map_fn(
          self._preprocess_image, (input_frames),
          fn_output_signature={
              'image': tf.float32,
          })
      if self._output_audio:
        inputs.update(
            tf.map_fn(
                self._preprocess_audio, (input_audio),
                fn_output_signature={'audio': tf.float32}))
    logits = self.inference_step(inputs)
    if self.params.task.train_data.is_multilabel:
      probs = tf.math.sigmoid(logits)
    else:
      probs = tf.nn.softmax(logits)
    return {'logits': logits, 'probs': probs}
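
  # A rough eager-mode sketch of calling `serve` directly; `module` and the
  # shapes below are illustrative, assuming a config with
  # feature_shape = [32, 224, 224, 3]:
  #
  #   frames = tf.zeros([1, 32, 224, 224, 3], dtype=tf.uint8)
  #   outputs = module.serve(frames, tf.zeros([1, 1]))
  #   outputs['probs']  # [batch_size, num_classes] probabilities.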

  def get_inference_signatures(self, function_keys: Dict[Text, Text]):
    """Gets defined function signatures.

    Args:
      function_keys: A dictionary whose keys name the functions to create
        signatures for and whose values are the signature keys to export
        them under.

    Returns:
      A dictionary mapping each signature key to a concrete function that
      can be used with tf.saved_model.save.
    """
    signatures = {}
    for key, def_name in function_keys.items():
      if key == 'image_tensor':
        input_signature = tf.TensorSpec(
            shape=[self._batch_size] + self._input_image_size + [3],
            dtype=tf.uint8,
            name='INPUT_FRAMES')
        signatures[
            def_name] = self.inference_from_image_tensors.get_concrete_function(
                input_signature)
      elif key == 'frames_audio':
        input_signature = [
            tf.TensorSpec(
                shape=[self._batch_size] + self._input_image_size + [3],
                dtype=tf.uint8,
                name='INPUT_FRAMES'),
            tf.TensorSpec(
                shape=[self._batch_size] +
                self.params.task.train_data.audio_feature_shape,
                dtype=tf.float32,
                name='INPUT_AUDIO')
        ]
        signatures[
            def_name] = self.inference_from_image_audio_tensors.get_concrete_function(
                *input_signature)
      elif key == 'serve_examples' or key == 'tf_example':
        input_signature = tf.TensorSpec(
            shape=[self._batch_size], dtype=tf.string)
        signatures[
            def_name] = self.inference_from_tf_example.get_concrete_function(
                input_signature)
      else:
        raise ValueError(f'Unrecognized `input_type`: {key}')
    return signatures
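

# Example export flow (a sketch only; the experiment name is hypothetical and
# the constructor arguments follow export_base.ExportModule, so treat them as
# assumptions rather than a verified recipe):
#
#   from official.core import exp_factory
#
#   params = exp_factory.get_exp_config('video_classification_kinetics600')
#   module = VideoClassificationModule(
#       params=params,
#       batch_size=1,
#       input_image_size=[32, 224, 224],
#       input_type='image_tensor')
#   signatures = module.get_inference_signatures(
#       {'image_tensor': 'serving_default'})
#   tf.saved_model.save(module, '/tmp/video_cls_model', signatures=signatures)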