# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Util functions for loading checkpoints.

Especially for loading a TensorFlow 1.x checkpoint into a TensorFlow 2.x
(Keras) model.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import re

from absl import logging
import tensorflow as tf, tf_keras


def _build_assignment_map(keras_model,
                          prefix='',
                          skip_variables_regex=None,
                          var_to_shape_map=None):
  """Builds the variable assignment map.

  Computes an assignment map for loading older checkpoints into a Keras
  model. Variable names are remapped from the original TPUEstimator model
  to the new Keras names.

  Args:
    keras_model: tf_keras.Model object to provide variables to assign.
    prefix: prefix in the variable name to be removed for alignment with
      names in the checkpoint.
    skip_variables_regex: regular expression to match the names of variables
      that do not need to be assigned.
    var_to_shape_map: variable name to shape mapping from the checkpoint.

  Returns:
    The variable assignment map.
  """
  assignment_map = {}

  checkpoint_names = []
  if var_to_shape_map:
    # Skip optimizer slot variables and the global step; they are not model
    # weights and should not be restored.
    # pylint: disable=g-long-lambda
    checkpoint_names = list(
        filter(
            lambda x: not x.endswith('Momentum') and not x.endswith(
                'global_step'), var_to_shape_map.keys()))
    # pylint: enable=g-long-lambda

  logging.info('Number of variables in the checkpoint %d',
               len(checkpoint_names))
  for var in keras_model.variables:
    var_name = var.name

    if skip_variables_regex and re.match(skip_variables_regex, var_name):
      continue
    # Trim the index of the variable.
    if ':' in var_name:
      var_name = var_name[:var_name.rindex(':')]
    if var_name.startswith(prefix):
      var_name = var_name[len(prefix):]

    if not var_to_shape_map:
      assignment_map[var_name] = var
      continue

    # Match the name with variables in the checkpoint.
    # pylint: disable=cell-var-from-loop
    match_names = list(
        filter(lambda x: x.endswith(var_name), checkpoint_names))
    # pylint: enable=cell-var-from-loop
    try:
      if match_names:
        assert len(match_names) == 1, 'more than one match for {}: {}'.format(
            var_name, match_names)
        checkpoint_names.remove(match_names[0])
        assignment_map[match_names[0]] = var
      else:
        logging.info('Error, variable name not found: %s', var_name)
    except Exception as e:
      logging.info('Error removing the match_name: %s', match_names)
      logging.info('Exception: %s', e)
      raise

  logging.info('Found matching variables in checkpoint: %d',
               len(assignment_map))
  return assignment_map


def _get_checkpoint_map(checkpoint_path):
  reader = tf.train.load_checkpoint(checkpoint_path)
  return reader.get_variable_to_shape_map()


def make_restore_checkpoint_fn(checkpoint_path, prefix='', skip_regex=None):
  """Returns a scaffold function to restore parameters from a v1 checkpoint.

  Args:
    checkpoint_path: path of the checkpoint folder or file.
      Example 1: '/path/to/model_dir/'
      Example 2: '/path/to/model.ckpt-22500'
    prefix: prefix in the variable name to be removed for alignment with
      names in the checkpoint.
    skip_regex: regular expression to match the names of variables that do
      not need to be assigned.

  Returns:
    Callable[[tf_keras.Model], None]. Fn to load a v1 checkpoint into a
    Keras model.
  """

  def _restore_checkpoint_fn(keras_model):
    """Loads the pretrained model through the scaffold function."""
    if not checkpoint_path:
      logging.info('checkpoint_path is empty')
      return
    var_prefix = prefix
    if prefix and not prefix.endswith('/'):
      var_prefix += '/'
    var_to_shape_map = _get_checkpoint_map(checkpoint_path)
    assert var_to_shape_map, 'var_to_shape_map should not be empty'
    vars_to_load = _build_assignment_map(
        keras_model,
        prefix=var_prefix,
        skip_variables_regex=skip_regex,
        var_to_shape_map=var_to_shape_map)
    if not vars_to_load:
      raise ValueError('Variables to load is empty.')
    tf.compat.v1.train.init_from_checkpoint(checkpoint_path, vars_to_load)

  return _restore_checkpoint_fn
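

# Example usage (a minimal sketch; `build_keras_model`, the checkpoint path,
# the prefix, and the regex below are hypothetical placeholders, not part of
# this module):
#
#   model = build_keras_model()
#   restore_fn = make_restore_checkpoint_fn(
#       '/path/to/model.ckpt-22500',
#       prefix='model',
#       skip_regex=r'.*global_step.*')
#   restore_fn(model)  # Remaps v1 checkpoint names and assigns the weights.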