Spaces:
Runtime error
Runtime error
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Build Movinet for video classification.

Reference: https://arxiv.org/pdf/2103.11511.pdf
"""
from typing import Any, Dict, Mapping, Optional, Sequence, Tuple, Union | |
from absl import logging | |
import tensorflow as tf, tf_keras | |
from official.projects.movinet.configs import movinet as cfg | |
from official.projects.movinet.modeling import movinet_layers_a2_modified | |
from official.vision.modeling import backbones | |
from official.vision.modeling import factory_3d as model_factory | |
class MovinetClassifier(tf_keras.Model):
  """A video classification class builder.

  Builds a functional Keras model from a 3D backbone followed by a
  `movinet_layers_a2_modified.ClassifierHead`. The model output is a dict with
  two entries: 'prediction' and 'vid_embedding'.
  """

  def __init__(
      self,
      backbone: tf_keras.Model,
      num_classes: int,
      encoder_dim: int = 768,
      input_specs: Optional[Mapping[str, tf_keras.layers.InputSpec]] = None,
      activation: str = 'swish',
      dropout_rate: float = 0.1,
      kernel_initializer: str = 'HeNormal',
      kernel_regularizer: Optional[tf_keras.regularizers.Regularizer] = None,
      bias_regularizer: Optional[tf_keras.regularizers.Regularizer] = None,
      output_states: bool = False,
      **kwargs):
    """Movinet initialization function.

    Args:
      backbone: A 3d backbone network.
      num_classes: Number of classes in classification task.
      encoder_dim: Value forwarded to the classifier head as `encoder_dim`.
        Presumably the width of the returned video embedding -- confirm
        against `ClassifierHead`.
      input_specs: Specs of the input tensor. Defaults to a single 'image'
        spec of shape [None, None, None, None, 3].
      activation: name of the main activation function.
      dropout_rate: Rate for dropout regularization.
      kernel_initializer: Kernel initializer for the final dense layer.
      kernel_regularizer: Kernel regularizer.
      bias_regularizer: Bias regularizer. Stored and serialized, but not
        forwarded to the classifier head by this class.
      output_states: if True, output intermediate states that can be used to run
        the model in streaming mode. Inputting the output states of the
        previous input clip with the current input clip will utilize a stream
        buffer for streaming video. NOTE(review): this class currently stores
        the flag but never adds the states to the model outputs.
      **kwargs: Keyword arguments to be passed.
    """
    if not input_specs:
      input_specs = {
          'image': tf_keras.layers.InputSpec(shape=[None, None, None, None, 3])
      }

    self._num_classes = num_classes
    self._input_specs = input_specs
    self._activation = activation
    self._dropout_rate = dropout_rate
    self._kernel_initializer = kernel_initializer
    self._kernel_regularizer = kernel_regularizer
    self._bias_regularizer = bias_regularizer
    self._output_states = output_states
    self._encoder_dim = encoder_dim

    state_specs = None
    if backbone.use_external_states:
      state_specs = backbone.initial_state_specs(
          input_shape=input_specs['image'].shape)

    inputs, outputs, vid_embed = self._build_network(
        backbone, input_specs, state_specs=state_specs)

    super().__init__(
        inputs=inputs,
        outputs={'prediction': outputs, 'vid_embedding': vid_embed},
        **kwargs)

    # Move backbone after super() call so Keras is happy
    self._backbone = backbone

  def _build_backbone(
      self,
      backbone: tf_keras.Model,
      input_specs: Mapping[str, tf_keras.layers.InputSpec],
      state_specs: Optional[Mapping[str, tf_keras.layers.InputSpec]] = None,
  ) -> Tuple[Mapping[str, Any], Any, Any]:
    """Builds the backbone network and gets states and endpoints.

    Args:
      backbone: the model backbone.
      input_specs: the model input spec to use.
      state_specs: a dict of states such that, if any of the keys match for a
        layer, will overwrite the contents of the buffer(s).

    Returns:
      inputs: a dict of input specs.
      endpoints: a dict of model endpoints.
      states: a dict of model states.

    Raises:
      ValueError: if the backbone uses external states and either returns a
        state name that was not passed in, or returns a state whose rank or
        known dimensions differ from the corresponding input state.
    """
    state_specs = state_specs if state_specs is not None else {}

    # One symbolic Keras input per state; the batch dimension is dropped from
    # the spec shape because `tf_keras.Input` adds it back.
    states_in = {
        name: tf_keras.Input(shape=spec.shape[1:], dtype=spec.dtype, name=name)
        for name, spec in state_specs.items()
    }
    image = tf_keras.Input(shape=input_specs['image'].shape[1:], name='image')
    inputs = {**states_in, 'image': image}

    endpoints, states_out = backbone(inputs)

    if backbone.use_external_states:
      extra_states = set(states_out) - set(states_in)
      if extra_states:
        raise ValueError(
            'Expected input and output states to be the same. Got extra states '
            '{}, expected {}'.format(extra_states, set(states_in)))

      mismatched_shapes = {}
      for name in states_out:
        before_shape = states_in[name].shape
        after_shape = states_out[name].shape
        if len(before_shape) != len(after_shape):
          mismatched_shapes[name] = (before_shape, after_shape)
          continue
        # Unknown (None) dimensions are treated as compatible with anything.
        for before, after in zip(before_shape, after_shape):
          if before is not None and after is not None and before != after:
            mismatched_shapes[name] = (before_shape, after_shape)
            break

      if mismatched_shapes:
        raise ValueError(
            'Got mismatched input and output state shapes: {}'.format(
                mismatched_shapes))

    return inputs, endpoints, states_out

  def _build_network(
      self,
      backbone: tf_keras.Model,
      input_specs: Mapping[str, tf_keras.layers.InputSpec],
      state_specs: Optional[Mapping[str, tf_keras.layers.InputSpec]] = None,
  ) -> Tuple[Mapping[str, Any], Any, Any]:
    """Builds the model network.

    Args:
      backbone: the model backbone.
      input_specs: the model input spec to use.
      state_specs: a dict of states such that, if any of the keys match for a
        layer, will overwrite the contents of the buffer(s).

    Returns:
      A 3-tuple `(inputs, prediction, vid_embed)`. `inputs` is a dict with the
      base input and states; `prediction` and `vid_embed` are the two values
      returned by the classifier head.
    """
    # The backbone states are validated in `_build_backbone` but are not
    # exposed as model outputs here, regardless of `output_states`.
    inputs, endpoints, _ = self._build_backbone(
        backbone=backbone, input_specs=input_specs, state_specs=state_specs)

    # The head is attached to a fixed backbone endpoint.
    x = endpoints['block4_layer6']

    head = movinet_layers_a2_modified.ClassifierHead(
        num_classes=self._num_classes,
        encoder_dim=self._encoder_dim,
        dropout_rate=self._dropout_rate,
        kernel_initializer=self._kernel_initializer,
        kernel_regularizer=self._kernel_regularizer,
        conv_type='conv',
        activation=self._activation)
    x, vid_embed = head(x)

    return inputs, x, vid_embed

  def initial_state_specs(
      self, input_shape: Sequence[int]) -> Dict[str, tf_keras.layers.InputSpec]:
    """Returns the backbone's initial state specs for `input_shape`."""
    return self._backbone.initial_state_specs(input_shape=input_shape)

  def init_states(self, input_shape: Sequence[int]) -> Dict[str, tf.Tensor]:
    """Returns initial states for the first call in streaming mode."""
    return self._backbone.init_states(input_shape)

  @property
  def checkpoint_items(self) -> Dict[str, Any]:
    """Returns a dictionary of items to be additionally checkpointed."""
    return dict(backbone=self.backbone)

  @property
  def backbone(self) -> tf_keras.Model:
    """Returns the backbone of the model."""
    return self._backbone

  def get_config(self):
    """Returns the constructor arguments needed to rebuild this model."""
    config = {
        'backbone': self._backbone,
        'activation': self._activation,
        'num_classes': self._num_classes,
        # Included so that a get_config/from_config round trip does not
        # silently fall back to the default encoder width.
        'encoder_dim': self._encoder_dim,
        'input_specs': self._input_specs,
        'dropout_rate': self._dropout_rate,
        'kernel_initializer': self._kernel_initializer,
        'kernel_regularizer': self._kernel_regularizer,
        'bias_regularizer': self._bias_regularizer,
        'output_states': self._output_states,
    }
    return config

  @classmethod
  def from_config(cls, config, custom_objects=None):
    """Rebuilds a model from the dict produced by `get_config`.

    Args:
      config: constructor keyword arguments. Entries of `config['input_specs']`
        that are plain dicts are deserialized first.
      custom_objects: unused; accepted for compatibility with the Keras
        `from_config` signature.

    Returns:
      A new instance of `cls`.
    """
    # Work on copies so the caller's config is not modified.
    config = dict(config)

    # Each InputSpec may need to be deserialized
    # This handles the case where we want to load a saved_model loaded with
    # `tf_keras.models.load_model`
    input_specs = config.get('input_specs')
    if input_specs:
      config['input_specs'] = {
          name: (tf_keras.layers.deserialize(spec)
                 if isinstance(spec, dict) else spec)
          for name, spec in input_specs.items()
      }
    return cls(**config)
def build_movinet_model(
    input_specs: Mapping[str, tf_keras.layers.InputSpec],
    model_config: cfg.MovinetModel,
    num_classes: int,
    encoder_dim: int = 768,
    l2_regularizer: Optional[tf_keras.regularizers.Regularizer] = None):
  """Builds a `MovinetClassifier` from a model config.

  Args:
    input_specs: spec handed to the backbone factory and also wrapped as
      `{'image': input_specs}` for the classifier. NOTE(review): the usage
      suggests a single `InputSpec` rather than a mapping -- confirm.
    model_config: config supplying `backbone`, `norm_activation`,
      `activation`, `dropout_rate` and `output_states`.
    num_classes: number of classes for the classification head.
    encoder_dim: forwarded to `MovinetClassifier` as `encoder_dim`.
    l2_regularizer: optional regularizer, used for the backbone and as the
      classifier's kernel regularizer.

  Returns:
    The constructed `MovinetClassifier`.
  """
  logging.info('Building movinet model with num classes: %s', num_classes)
  if l2_regularizer is not None:
    regularizer_config = l2_regularizer.get_config()
    logging.info('Building movinet model with regularizer: %s',
                 regularizer_config)

  movinet_backbone = backbones.factory.build_backbone(
      input_specs=input_specs,
      backbone_config=model_config.backbone,
      norm_activation_config=model_config.norm_activation,
      l2_regularizer=l2_regularizer)

  return MovinetClassifier(
      movinet_backbone,
      num_classes=num_classes,
      encoder_dim=encoder_dim,
      kernel_regularizer=l2_regularizer,
      input_specs={'image': input_specs},
      activation=model_config.activation,
      dropout_rate=model_config.dropout_rate,
      output_states=model_config.output_states)