|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Base Model definition.""" |
|
|
|
from __future__ import absolute_import |
|
from __future__ import division |
|
from __future__ import print_function |
|
|
|
import abc |
|
import functools |
|
import re |
|
import tensorflow as tf |
|
from official.vision.detection.modeling import checkpoint_utils |
|
from official.vision.detection.modeling import learning_rates |
|
from official.vision.detection.modeling import optimizers |
|
|
|
|
|
def _make_filter_trainable_variables_fn(frozen_variable_prefix):
  """Builds a callable that drops frozen variables from a variable list.

  Args:
    frozen_variable_prefix: a regular expression tested with `re.match`
      (i.e. anchored at the start) against each variable name. When it is
      falsy (None or empty), no variable is treated as frozen.

  Returns:
    A function mapping a list of variables to the list of non-frozen ones.
  """

  def _filter_trainable_variables(variables):
    """Filters out the variables whose names match the frozen prefix.

    Args:
      variables: a list of tf.Variable to be filtered.

    Returns:
      A new list with the variables that are not frozen, in input order.
    """
    if not frozen_variable_prefix:
      # Nothing is frozen: hand back a fresh copy of the input.
      return list(variables)

    kept = []
    for var in variables:
      if re.match(frozen_variable_prefix, var.name) is None:
        kept.append(var)
    return kept

  return _filter_trainable_variables
|
|
|
|
|
class Model(object):
  """Base class for model function.

  Subclasses supply the network (`build_model`), the forward pass
  (`build_outputs`) and the loss (`build_loss_fn`). This class stores the
  training configuration they share: optimizer factory, learning rate,
  frozen / regularized variable patterns and checkpoint settings.
  """

  # NOTE: this is the Python 2 spelling of a metaclass declaration. Under
  # Python 3 it is only a plain class attribute, so the `abc.abstractmethod`
  # markers below are not enforced at instantiation time there.
  __metaclass__ = abc.ABCMeta

  def __init__(self, params):
    """Reads the training configuration from `params`.

    Args:
      params: configuration object exposing `architecture.use_bfloat16`,
        `train.optimizer`, `train.total_steps`, `train.learning_rate`,
        `train.frozen_variable_prefix`,
        `train.regularization_variable_regex`, `train.l2_weight_decay`,
        `train.checkpoint` (with an `as_dict()` method), `enable_summary`
        and `model_dir`.
    """
    self._use_bfloat16 = params.architecture.use_bfloat16

    if params.architecture.use_bfloat16:
      # Install the 'mixed_bfloat16' Keras mixed precision policy.
      policy = tf.compat.v2.keras.mixed_precision.experimental.Policy(
          'mixed_bfloat16')
      tf.compat.v2.keras.mixed_precision.experimental.set_policy(policy)

    # Optimization.
    self._optimizer_fn = optimizers.OptimizerFactory(params.train.optimizer)
    self._learning_rate = learning_rates.learning_rate_generator(
        params.train.total_steps, params.train.learning_rate)

    self._frozen_variable_prefix = params.train.frozen_variable_prefix
    self._regularization_var_regex = params.train.regularization_variable_regex
    self._l2_weight_decay = params.train.l2_weight_decay

    # Checkpoint restoration settings, kept as a plain dict.
    self._checkpoint = params.train.checkpoint.as_dict()

    # Summary.
    self._enable_summary = params.enable_summary
    self._model_dir = params.model_dir

  @abc.abstractmethod
  def build_outputs(self, inputs, mode):
    """Build the graph of the forward path.

    Args:
      inputs: the model inputs.
      mode: the execution mode, forwarded unchanged by `model_outputs`.
    """
    pass

  @abc.abstractmethod
  def build_model(self, params, mode):
    """Build the model object."""
    pass

  @abc.abstractmethod
  def build_loss_fn(self):
    """Build the loss function."""
    pass

  def post_processing(self, labels, outputs):
    """Post-processing function.

    The base implementation is the identity; subclasses may override it.

    Args:
      labels: the labels.
      outputs: the model outputs.

    Returns:
      The `(labels, outputs)` pair, unchanged.
    """
    return labels, outputs

  def model_outputs(self, inputs, mode):
    """Build the model outputs by delegating to `build_outputs`."""
    return self.build_outputs(inputs, mode)

  def build_optimizer(self):
    """Returns the optimizer built by the factory for the learning rate."""
    return self._optimizer_fn(self._learning_rate)

  def make_filter_trainable_variables_fn(self):
    """Creates a function for filtering trainable variables.

    Returns:
      A function that removes from a variable list those whose names match
      the configured frozen variable prefix.
    """
    return _make_filter_trainable_variables_fn(self._frozen_variable_prefix)

  def weight_decay_loss(self, trainable_variables):
    """Computes the L2 weight decay loss over the regularized variables.

    Args:
      trainable_variables: a list of variables considered for regularization.

    Returns:
      `l2_weight_decay` times the sum of `tf.nn.l2_loss` over every variable
      whose name matches the regularization regex (every variable when the
      regex is None), or a scalar zero tensor when no variable is selected.
    """
    reg_variables = [
        v for v in trainable_variables
        if self._regularization_var_regex is None
        or re.match(self._regularization_var_regex, v.name)
    ]

    if not reg_variables:
      # tf.add_n rejects an empty list; with nothing to regularize the
      # penalty is simply zero.
      return tf.constant(0.0)

    return self._l2_weight_decay * tf.add_n(
        [tf.nn.l2_loss(v) for v in reg_variables])

  def make_restore_checkpoint_fn(self):
    """Returns scaffold function to restore parameters from v1 checkpoint."""
    # The skip pattern is optional in the checkpoint config; None when absent.
    skip_regex = self._checkpoint.get('skip_checkpoint_variables')
    return checkpoint_utils.make_restore_checkpoint_fn(
        self._checkpoint['path'],
        prefix=self._checkpoint['prefix'],
        skip_regex=skip_regex)

  def eval_metrics(self):
    """Returns tuple of metric function and its inputs for evaluation.

    Raises:
      NotImplementedError: always, in this base class; subclasses that
        support evaluation are expected to override it.
    """
    raise NotImplementedError('Unimplemented eval_metrics')
|
|