# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Definitions for MoViNet structures.

Reference: "MoViNets: Mobile Video Networks for Efficient Video Recognition"
https://arxiv.org/pdf/2103.11511.pdf

MoViNets are efficient video classification networks that are part of a model
family, ranging from the smallest model, MoViNet-A0, to the largest model,
MoViNet-A6. Each model has various width, depth, input resolution, and input
frame-rate associated with them. See the main paper for more details.
"""

import dataclasses

from official.core import config_definitions as cfg
from official.core import exp_factory
from official.modeling import hyperparams
from official.vision.configs import backbones_3d
from official.vision.configs import common
from official.vision.configs import video_classification


@dataclasses.dataclass
class Movinet(hyperparams.Config):
  """Base backbone configuration shared by every MoViNet variant."""

  model_id: str = 'a0'
  causal: bool = False
  use_positional_encoding: bool = False
  # Convolution flavour, one of ['3d', '2plus1d', '3d_2plus1d']:
  #   3d:         default 3D convolution.
  #   2plus1d:    (2+1)D convolution with Conv2D (2D reshaping).
  #   3d_2plus1d: (2+1)D convolution with Conv3D (no 2D reshaping).
  conv_type: str = '3d'
  # Squeeze-excitation pooling flavour, one of ['3d', '2d', '2plus3d']:
  #   3d:      default 3D global average pooling.
  #   2d:      2D global average pooling.
  #   2plus3d: concatenation of 2D and 3D global average pooling.
  se_type: str = '3d'
  activation: str = 'swish'
  gating_activation: str = 'sigmoid'
  stochastic_depth_drop_rate: float = 0.2
  use_external_states: bool = False
  average_pooling_type: str = '3d'
  output_states: bool = True


@dataclasses.dataclass
class MovinetA0(Movinet):
  """Backbone config for MoViNet-A0.

  Represents the smallest base MoViNet searched by NAS.

  Reference: https://arxiv.org/pdf/2103.11511.pdf
  """

  model_id: str = 'a0'


@dataclasses.dataclass
class MovinetA1(Movinet):
  """Backbone config for MoViNet-A1."""

  model_id: str = 'a1'


@dataclasses.dataclass
class MovinetA2(Movinet):
  """Backbone config for MoViNet-A2."""

  model_id: str = 'a2'


@dataclasses.dataclass
class MovinetA3(Movinet):
  """Backbone config for MoViNet-A3."""

  model_id: str = 'a3'


@dataclasses.dataclass
class MovinetA4(Movinet):
  """Backbone config for MoViNet-A4."""

  model_id: str = 'a4'


@dataclasses.dataclass
class MovinetA5(Movinet):
  """Backbone config for MoViNet-A5.

  Represents the largest base MoViNet searched by NAS.
  """

  model_id: str = 'a5'


@dataclasses.dataclass
class MovinetT0(Movinet):
  """Backbone config for MoViNet-T0.

  MoViNet-T0 is a smaller version of MoViNet-A0 for even faster processing.
  """

  model_id: str = 't0'


@dataclasses.dataclass
class Backbone3D(backbones_3d.Backbone3D):
  """Configuration for backbones.

  Attributes:
    type: 'str', type of backbone to be used, one of the fields below.
    movinet: movinet backbone config.
  """

  type: str = 'movinet'
  movinet: Movinet = dataclasses.field(default_factory=Movinet)


def _default_norm_activation() -> common.NormActivation:
  """Builds the default normalization settings used by `MovinetModel`."""
  return common.NormActivation(
      activation=None,  # legacy flag, not used.
      norm_momentum=0.99,
      norm_epsilon=1e-3,
      use_sync_bn=True,
  )


@dataclasses.dataclass
class MovinetModel(video_classification.VideoClassificationModel):
  """The MoViNet model config."""

  model_type: str = 'movinet'
  backbone: Backbone3D = dataclasses.field(default_factory=Backbone3D)
  # A fresh NormActivation is created per instance so that configs never
  # share one mutable default object.
  norm_activation: common.NormActivation = dataclasses.field(
      default_factory=_default_norm_activation
  )
  activation: str = 'swish'
  output_states: bool = False


@exp_factory.register_config_factory('movinet_kinetics600')
def movinet_kinetics600() -> cfg.ExperimentConfig:
  """Returns the Kinetics-600 video classification experiment using MoViNet.

  Starts from the stock `video_classification_kinetics600` experiment, sets
  the train and validation input dtype to bfloat16, and replaces the task
  model with a default `MovinetModel`.
  """
  config = video_classification.video_classification_kinetics600()
  # Both input pipelines use the same reduced-precision dtype.
  for data_config in (config.task.train_data, config.task.validation_data):
    data_config.dtype = 'bfloat16'
  config.task.model = MovinetModel()
  return config