# NOTE: scraped page header, not part of the module:
# "deanna-emery's picture / updates / 93528c6 / raw / history blame / 3.54 kB"
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Factory methods to build models."""
# Import libraries
import tensorflow as tf, tf_keras
from official.core import registry
from official.vision.configs import video_classification as video_classification_cfg
from official.vision.modeling import video_classification_model
from official.vision.modeling import backbones
# Maps a model-type key to its builder callable. Entries are added through
# `register_model_builder` and read back by `build_model`.
_REGISTERED_MODEL_CLS = {}
def register_model_builder(key: str):
  """Returns a decorator that registers a model builder under `key`.

  The decorated builder may be any callable (a class or a function). It is
  stored in this module's model registry and later retrieved by `build_model`,
  which invokes it positionally as
  `builder(input_specs, model_config, num_classes, l2_regularizer)`.

  Example:

  ```
  class MyModel(tf_keras.Model):
    pass

  @register_model_builder('mymodel')
  def builder(input_specs, model_config, num_classes, l2_regularizer):
    return MyModel(...)

  # Builds a MyModel object.
  my_model = build_model(
      'mymodel', input_specs, model_config, num_classes, l2_regularizer)
  ```

  Args:
    key: the key under which the builder is registered and later looked up.

  Returns:
    A callable for use as a decorator that registers the decorated builder.
  """
  decorator = registry.register(_REGISTERED_MODEL_CLS, key)
  return decorator
def build_model(
    model_type: str,
    input_specs: tf_keras.layers.InputSpec,
    model_config: video_classification_cfg.hyperparams.Config,
    num_classes: int,
    l2_regularizer: tf_keras.regularizers.Regularizer = None) -> tf_keras.Model:
  """Builds a model by dispatching to the builder registered for `model_type`.

  Args:
    model_type: string name of the model type; it is the key used to look up
      the builder in the model registry. It should be consistent with
      ModelConfig.model_type.
    input_specs: tf_keras.layers.InputSpec, forwarded to the builder.
    model_config: the model config, forwarded to the builder.
    num_classes: number of classes, forwarded to the builder.
    l2_regularizer: tf_keras.regularizers.Regularizer instance, forwarded to
      the builder. Defaults to None.

  Returns:
    Whatever the registered builder returns (annotated as a tf_keras.Model).
  """
  # The builder is called positionally, so registered builders only need to
  # agree on argument order, not on parameter names.
  return registry.lookup(_REGISTERED_MODEL_CLS, model_type)(
      input_specs, model_config, num_classes, l2_regularizer)
@register_model_builder('video_classification')
def build_video_classification_model(
    input_specs: tf_keras.layers.InputSpec,
    model_config: video_classification_cfg.VideoClassificationModel,
    num_classes: int,
    l2_regularizer: tf_keras.regularizers.Regularizer = None) -> tf_keras.Model:
  """Builds the video classification model.

  A backbone is first created from `model_config.backbone` and
  `model_config.norm_activation`; it is then wrapped in a
  `VideoClassificationModel` configured from the remaining `model_config`
  fields. This builder is registered under the key 'video_classification'.

  Args:
    input_specs: tf_keras.layers.InputSpec. Passed as-is to the backbone
      factory and, keyed by 'image', to the classification model.
    model_config: video classification model config; the fields read here are
      `backbone`, `norm_activation`, `dropout_rate`, `aggregate_endpoints`
      and `require_endpoints`.
    num_classes: number of classes.
    l2_regularizer: tf_keras.regularizers.Regularizer instance shared by the
      backbone and the classification model. Defaults to None.

  Returns:
    The constructed `VideoClassificationModel`.
  """
  norm_act_cfg = model_config.norm_activation
  video_backbone = backbones.factory.build_backbone(
      input_specs=input_specs,
      backbone_config=model_config.backbone,
      norm_activation_config=norm_act_cfg,
      l2_regularizer=l2_regularizer)

  # The classification model takes its input specs as a dict keyed by input
  # name, unlike the backbone factory which takes the bare InputSpec.
  return video_classification_model.VideoClassificationModel(
      backbone=video_backbone,
      num_classes=num_classes,
      input_specs={'image': input_specs},
      dropout_rate=model_config.dropout_rate,
      aggregate_endpoints=model_config.aggregate_endpoints,
      kernel_regularizer=l2_regularizer,
      require_endpoints=model_config.require_endpoints)