# Copyright 2023 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Build Movinet for video classification. Reference: https://arxiv.org/pdf/2103.11511.pdf """ from typing import Any, Dict, Mapping, Optional, Sequence, Tuple, Union from absl import logging import tensorflow as tf, tf_keras from official.projects.movinet.configs import movinet as cfg from official.projects.movinet.modeling import movinet_layers_a2_modified from official.vision.modeling import backbones from official.vision.modeling import factory_3d as model_factory @tf_keras.utils.register_keras_serializable(package='Vision') class MovinetClassifier(tf_keras.Model): """A video classification class builder.""" def __init__( self, backbone: tf_keras.Model, num_classes: int, encoder_dim: int = 768, input_specs: Optional[Mapping[str, tf_keras.layers.InputSpec]] = None, activation: str = 'swish', dropout_rate: float = 0.1, kernel_initializer: str = 'HeNormal', kernel_regularizer: Optional[tf_keras.regularizers.Regularizer] = None, bias_regularizer: Optional[tf_keras.regularizers.Regularizer] = None, output_states: bool = False, **kwargs): """Movinet initialization function. Args: backbone: A 3d backbone network. num_classes: Number of classes in classification task. input_specs: Specs of the input tensor. activation: name of the main activation function. dropout_rate: Rate for dropout regularization. kernel_initializer: Kernel initializer for the final dense layer. kernel_regularizer: Kernel regularizer. bias_regularizer: Bias regularizer. output_states: if True, output intermediate states that can be used to run the model in streaming mode. Inputting the output states of the previous input clip with the current input clip will utilize a stream buffer for streaming video. **kwargs: Keyword arguments to be passed. """ if not input_specs: input_specs = { 'image': tf_keras.layers.InputSpec(shape=[None, None, None, None, 3]) } self._num_classes = num_classes self._input_specs = input_specs self._activation = activation self._dropout_rate = dropout_rate self._kernel_initializer = kernel_initializer self._kernel_regularizer = kernel_regularizer self._bias_regularizer = bias_regularizer self._output_states = output_states self._encoder_dim = encoder_dim state_specs = None if backbone.use_external_states: state_specs = backbone.initial_state_specs( input_shape=input_specs['image'].shape) inputs, outputs, vid_embed = self._build_network( backbone, input_specs, state_specs=state_specs) super(MovinetClassifier, self).__init__( inputs=inputs, outputs={'prediction':outputs, 'vid_embedding':vid_embed}, **kwargs) # Move backbone after super() call so Keras is happy self._backbone = backbone def _build_backbone( self, backbone: tf_keras.Model, input_specs: Mapping[str, tf_keras.layers.InputSpec], state_specs: Optional[Mapping[str, tf_keras.layers.InputSpec]] = None, ) -> Tuple[Mapping[str, Any], Any, Any]: """Builds the backbone network and gets states and endpoints. Args: backbone: the model backbone. input_specs: the model input spec to use. state_specs: a dict of states such that, if any of the keys match for a layer, will overwrite the contents of the buffer(s). Returns: inputs: a dict of input specs. endpoints: a dict of model endpoints. states: a dict of model states. """ state_specs = state_specs if state_specs is not None else {} states = { name: tf_keras.Input(shape=spec.shape[1:], dtype=spec.dtype, name=name) for name, spec in state_specs.items() } image = tf_keras.Input(shape=input_specs['image'].shape[1:], name='image') inputs = {**states, 'image': image} if backbone.use_external_states: before_states = states endpoints, states = backbone(inputs) after_states = states new_states = set(after_states) - set(before_states) if new_states: raise ValueError( 'Expected input and output states to be the same. Got extra states ' '{}, expected {}'.format(new_states, set(before_states))) mismatched_shapes = {} for name in after_states: before_shape = before_states[name].shape after_shape = after_states[name].shape if len(before_shape) != len(after_shape): mismatched_shapes[name] = (before_shape, after_shape) continue for before, after in zip(before_shape, after_shape): if before is not None and after is not None and before != after: mismatched_shapes[name] = (before_shape, after_shape) break if mismatched_shapes: raise ValueError( 'Got mismatched input and output state shapes: {}'.format( mismatched_shapes)) else: endpoints, states = backbone(inputs) return inputs, endpoints, states def _build_network( self, backbone: tf_keras.Model, input_specs: Mapping[str, tf_keras.layers.InputSpec], state_specs: Optional[Mapping[str, tf_keras.layers.InputSpec]] = None, ) -> Tuple[Mapping[str, tf_keras.Input], Union[Tuple[Mapping[ # pytype: disable=invalid-annotation # typed-keras str, tf.Tensor], Mapping[str, tf.Tensor]], Mapping[str, tf.Tensor]]]: """Builds the model network. Args: backbone: the model backbone. input_specs: the model input spec to use. state_specs: a dict of states such that, if any of the keys match for a layer, will overwrite the contents of the buffer(s). Returns: Inputs and outputs as a tuple. Inputs are expected to be a dict with base input and states. Outputs are expected to be a dict of endpoints and (optionally) output states. """ inputs, endpoints, states = self._build_backbone( backbone=backbone, input_specs=input_specs, state_specs=state_specs) x = endpoints['block4_layer6'] x, vid_embed = movinet_layers_a2_modified.ClassifierHead( num_classes=self._num_classes, encoder_dim=self._encoder_dim, dropout_rate=self._dropout_rate, kernel_initializer=self._kernel_initializer, kernel_regularizer=self._kernel_regularizer, conv_type='conv', activation=self._activation)( x) # outputs = (x, vid_embed) if self._output_states else (x, vid_embed) return inputs, x, vid_embed def initial_state_specs( self, input_shape: Sequence[int]) -> Dict[str, tf_keras.layers.InputSpec]: return self._backbone.initial_state_specs(input_shape=input_shape) @tf.function def init_states(self, input_shape: Sequence[int]) -> Dict[str, tf.Tensor]: """Returns initial states for the first call in steaming mode.""" return self._backbone.init_states(input_shape) @property def checkpoint_items(self) -> Dict[str, Any]: """Returns a dictionary of items to be additionally checkpointed.""" return dict(backbone=self.backbone) @property def backbone(self) -> tf_keras.Model: """Returns the backbone of the model.""" return self._backbone def get_config(self): config = { 'backbone': self._backbone, 'activation': self._activation, 'num_classes': self._num_classes, 'input_specs': self._input_specs, 'dropout_rate': self._dropout_rate, 'kernel_initializer': self._kernel_initializer, 'kernel_regularizer': self._kernel_regularizer, 'bias_regularizer': self._bias_regularizer, 'output_states': self._output_states, } return config @classmethod def from_config(cls, config, custom_objects=None): # Each InputSpec may need to be deserialized # This handles the case where we want to load a saved_model loaded with # `tf_keras.models.load_model` if config['input_specs']: for name in config['input_specs']: if isinstance(config['input_specs'][name], dict): config['input_specs'][name] = tf_keras.layers.deserialize( config['input_specs'][name]) return cls(**config) @model_factory.register_model_builder('movinet') def build_movinet_model( input_specs: Mapping[str, tf_keras.layers.InputSpec], model_config: cfg.MovinetModel, num_classes: int, encoder_dim: int = 768, l2_regularizer: Optional[tf_keras.regularizers.Regularizer] = None): """Builds movinet model.""" logging.info('Building movinet model with num classes: %s', num_classes) if l2_regularizer is not None: logging.info('Building movinet model with regularizer: %s', l2_regularizer.get_config()) input_specs_dict = {'image': input_specs} backbone = backbones.factory.build_backbone( input_specs=input_specs, backbone_config=model_config.backbone, norm_activation_config=model_config.norm_activation, l2_regularizer=l2_regularizer) model = MovinetClassifier( backbone, num_classes=num_classes, encoder_dim=encoder_dim, kernel_regularizer=l2_regularizer, input_specs=input_specs_dict, activation=model_config.activation, dropout_rate=model_config.dropout_rate, output_states=model_config.output_states) return model