deanna-emery's picture
updates
93528c6
raw
history blame
22.1 kB
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Training utils."""
import dataclasses
import inspect
import json
import os
import pprint
from typing import Any, Callable, Dict, List, Optional, Union
from absl import logging
import gin
import numpy as np
import orbit
import tensorflow as tf, tf_keras
# pylint: disable=g-direct-tensorflow-import
from tensorflow.python.framework import ops
from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2_as_graph
# pylint: enable=g-direct-tensorflow-import
from official.core import base_task
from official.core import base_trainer
from official.core import config_definitions
from official.core import exp_factory
from official.modeling import hyperparams
BEST_CHECKPOINT_NAME = 'best_ckpt'
def get_leaf_nested_dict(d: Dict[str, Any], keys: List[str]) -> Dict[str, Any]:
"""Get leaf from a dictionary with arbitrary depth with a list of keys.
Args:
d: The dictionary to extract value from.
keys: The list of keys to extract values recursively.
Returns:
The value of the leaf.
Raises:
KeyError: If the value of keys extracted is a dictionary.
"""
leaf = d
for k in keys:
if not isinstance(leaf, dict) or k not in leaf:
raise KeyError(
'Path not exist while traversing the dictionary: d with keys'
': %s.' % keys)
leaf = leaf[k]
if isinstance(leaf, dict):
raise KeyError('The value extracted with keys: %s is not a leaf of the '
'dictionary: %s.' % (keys, d))
return leaf
def cast_leaf_nested_dict(d: Dict[str, Any],
cast_fn: Callable[[Any], Any]) -> Dict[str, Any]:
"""Cast the leaves of a dictionary with arbitrary depth in place.
Args:
d: The dictionary to extract value from.
cast_fn: The casting function.
Returns:
A dictionray with the same structure as d.
"""
for key, value in d.items():
if isinstance(value, dict):
d[key] = cast_leaf_nested_dict(value, cast_fn)
else:
d[key] = cast_fn(value)
return d
def _filter_leaf_nested_dict(
d: Dict[str, Any], predicate: Callable[[Any], bool]
) -> Dict[str, Any]:
"""Filters the leaves of a dictionary with arbitrary depth in place.
Args:
d: The dictionary to extract value from.
predicate: A function that will be called on every leave item. When the
function returns True the leave will be kept. Otherwise the leave will be
dropped.
Returns:
A new dictionray with filtered result.
"""
result = {}
for key, value in d.items():
if isinstance(value, dict):
result[key] = _filter_leaf_nested_dict(value, predicate)
elif predicate(value):
result[key] = value
return result
def maybe_create_best_ckpt_exporter(params: config_definitions.ExperimentConfig,
data_dir: str) -> Any:
"""Maybe create a BestCheckpointExporter object, according to the config."""
export_subdir = params.trainer.best_checkpoint_export_subdir
metric_name = params.trainer.best_checkpoint_eval_metric
metric_comp = params.trainer.best_checkpoint_metric_comp
if data_dir and export_subdir and metric_name:
best_ckpt_dir = os.path.join(data_dir, export_subdir)
best_ckpt_exporter = BestCheckpointExporter(best_ckpt_dir, metric_name,
metric_comp)
logging.info(
'Created the best checkpoint exporter. '
'data_dir: %s, export_subdir: %s, metric_name: %s', data_dir,
export_subdir, metric_name)
else:
best_ckpt_exporter = None
return best_ckpt_exporter
class BestCheckpointExporter:
"""Keeps track of the best result, and saves its checkpoint.
Orbit will support an API for checkpoint exporter. This class will be used
together with orbit once this functionality is ready.
"""
def __init__(self, export_dir: str, metric_name: str, metric_comp: str):
"""Initialization.
Args:
export_dir: The directory that will contain exported checkpoints.
metric_name: Indicates which metric to look at, when determining which
result is better. If eval_logs being passed to maybe_export_checkpoint
is a nested dictionary, use `|` as a seperator for different layers.
metric_comp: Indicates how to compare results. Either `lower` or `higher`.
"""
self._export_dir = export_dir
self._metric_name = metric_name.split('|')
self._metric_comp = metric_comp
if self._metric_comp not in ('lower', 'higher'):
raise ValueError('best checkpoint metric comp must be one of '
'higher, lower. Got: {}'.format(self._metric_comp))
tf.io.gfile.makedirs(os.path.dirname(self.best_ckpt_logs_path))
self._best_ckpt_logs = self._maybe_load_best_eval_metric()
self._checkpoint_manager = None
def _get_checkpoint_manager(self, checkpoint):
"""Gets an existing checkpoint manager or creates a new one."""
if self._checkpoint_manager is None or (self._checkpoint_manager.checkpoint
!= checkpoint):
logging.info('Creates a new checkpoint manager.')
self._checkpoint_manager = tf.train.CheckpointManager(
checkpoint,
directory=self._export_dir,
max_to_keep=1,
checkpoint_name=BEST_CHECKPOINT_NAME)
return self._checkpoint_manager
def maybe_export_checkpoint(
self, checkpoint, eval_logs, global_step, write_logs=True) -> bool:
"""Compare eval_logs with past eval_logs and export checkpoint if better."""
logging.info('[BestCheckpointExporter] received eval_logs: %s, at step: %d',
eval_logs, global_step)
if self._best_ckpt_logs is None or self._new_metric_is_better(
self._best_ckpt_logs, eval_logs):
self._best_ckpt_logs = eval_logs
if write_logs:
self.export_best_eval_metric(self._best_ckpt_logs, global_step)
self._get_checkpoint_manager(checkpoint).save()
return True
return False
def _maybe_load_best_eval_metric(self):
if not tf.io.gfile.exists(self.best_ckpt_logs_path):
return None
with tf.io.gfile.GFile(self.best_ckpt_logs_path, 'r') as reader:
return json.loads(reader.read())
def _new_metric_is_better(self, old_logs, new_logs):
"""Check if the metric in new_logs is better than the metric in old_logs."""
old_value = float(
orbit.utils.get_value(
get_leaf_nested_dict(old_logs, self._metric_name)))
new_value = float(
orbit.utils.get_value(
get_leaf_nested_dict(new_logs, self._metric_name)))
logging.info('[BestCheckpointExporter] comparing results. old: %f, new: %f',
old_value, new_value)
if self._metric_comp == 'higher':
if new_value > old_value:
logging.info('[BestCheckpointExporter] '
'the new number is better since it is higher.')
return True
else: # self._metric_comp == 'lower':
if new_value < old_value:
logging.info('[BestCheckpointExporter] '
'the new number is better since it is lower.')
return True
return False
def export_best_eval_metric(self, eval_logs, global_step):
"""Export evaluation results of the best checkpoint into a json file."""
# eval_log_ext may contains non-scalar tensors, such as image data when
# `allow_image_summary` is True. Here we only keep scalar tensors.
eval_logs_ext = _filter_leaf_nested_dict(
eval_logs, lambda x: tf.rank(x) <= 1
)
eval_logs_ext['best_ckpt_global_step'] = global_step
eval_logs_ext = cast_leaf_nested_dict(
eval_logs_ext, lambda x: float(orbit.utils.get_value(x)))
# Saving json file is very fast.
with tf.io.gfile.GFile(self.best_ckpt_logs_path, 'w') as writer:
writer.write(json.dumps(eval_logs_ext, indent=4) + '\n')
@property
def best_ckpt_logs(self):
return self._best_ckpt_logs
@property
def best_ckpt_logs_path(self):
return os.path.join(self._export_dir, 'info.json')
@property
def best_ckpt_path(self):
"""Returns the best ckpt path or None if there is no ckpt yet."""
return tf.train.latest_checkpoint(self._export_dir)
def create_optimizer(task: base_task.Task,
params: config_definitions.ExperimentConfig
) -> tf_keras.optimizers.Optimizer:
"""A create optimizer util to be backward compatability with new args."""
if 'dp_config' in inspect.signature(task.create_optimizer).parameters:
dp_config = None
if hasattr(params.task, 'differential_privacy_config'):
dp_config = params.task.differential_privacy_config
optimizer = task.create_optimizer(
params.trainer.optimizer_config, params.runtime,
dp_config=dp_config)
else:
if hasattr(params.task, 'differential_privacy_config'
) and params.task.differential_privacy_config is not None:
raise ValueError('Differential privacy config is specified but '
'task.create_optimizer api does not accept it.')
optimizer = task.create_optimizer(
params.trainer.optimizer_config,
params.runtime)
return optimizer
@gin.configurable
def create_trainer(params: config_definitions.ExperimentConfig,
task: base_task.Task,
train: bool,
evaluate: bool,
checkpoint_exporter: Optional[BestCheckpointExporter] = None,
trainer_cls=base_trainer.Trainer) -> base_trainer.Trainer:
"""Create trainer."""
logging.info('Running default trainer.')
model = task.build_model()
optimizer = create_optimizer(task, params)
return trainer_cls(
params,
task,
model=model,
optimizer=optimizer,
train=train,
evaluate=evaluate,
checkpoint_exporter=checkpoint_exporter)
@dataclasses.dataclass
class ParseConfigOptions:
"""Use this dataclass instead of FLAGS to customize parse_configuration()."""
experiment: str
config_file: List[str]
tpu: str = ''
tf_data_service: str = ''
params_override: str = ''
def __contains__(self, name):
return name in dataclasses.asdict(self)
class ExperimentParser:
"""Constructs the Experiment config from Flags or equivalent object.
Most of the cases, users only need to call the `parse()` function:
```
builder = ExperimentParser(FLAGS)
params = builder.parse()
```
The advanced users can modify the flow by calling the parse_*() functions
separately.
"""
def __init__(self, flags_obj):
self._flags_obj = flags_obj
def parse(self):
"""Overrall process of constructing Experiment config."""
params = self.base_experiment()
params = self.parse_config_file(params)
params = self.parse_runtime(params)
params = self.parse_data_service(params)
params = self.parse_params_override(params)
return params
def base_experiment(self):
"""Get the base experiment config from --experiment field."""
if self._flags_obj.experiment is None:
raise ValueError('The flag --experiment must be specified.')
return exp_factory.get_exp_config(self._flags_obj.experiment)
def parse_config_file(self, params):
"""Override the configs of params from the config_file."""
for config_file in self._flags_obj.config_file or []:
params = hyperparams.override_params_dict(
params, config_file, is_strict=True)
return params
def parse_runtime(self, params):
"""Override the runtime configs of params from flags."""
# Override the TPU address and tf.data service address.
params.override({
'runtime': {
'tpu': self._flags_obj.tpu,
},
})
return params
def parse_data_service(self, params):
"""Override the data service configs of params from flags."""
if ('tf_data_service' in self._flags_obj and
self._flags_obj.tf_data_service and
isinstance(params.task, config_definitions.TaskConfig)):
params.override({
'task': {
'train_data': {
'tf_data_service_address': self._flags_obj.tf_data_service,
},
'validation_data': {
'tf_data_service_address': self._flags_obj.tf_data_service,
}
}
})
return params
def parse_params_override(self, params):
# Get the second level of override from `--params_override`.
# `--params_override` is typically used as a further override over the
# template. For example, one may define a particular template for training
# ResNet50 on ImageNet in a config file and pass it via `--config_file`,
# then define different learning rates and pass it via `--params_override`.
if self._flags_obj.params_override:
params = hyperparams.override_params_dict(
params, self._flags_obj.params_override, is_strict=True)
return params
def parse_configuration(flags_obj, lock_return=True, print_return=True):
"""Parses ExperimentConfig from flags."""
params = ExperimentParser(flags_obj).parse()
params.validate()
if lock_return:
params.lock()
if print_return:
pp = pprint.PrettyPrinter()
logging.info('Final experiment parameters:\n%s',
pp.pformat(params.as_dict()))
return params
def serialize_config(params: config_definitions.ExperimentConfig,
model_dir: str):
"""Serializes and saves the experiment config."""
if model_dir is None:
raise ValueError('model_dir must be specified, but got None')
params_save_path = os.path.join(model_dir, 'params.yaml')
logging.info('Saving experiment configuration to %s', params_save_path)
tf.io.gfile.makedirs(model_dir)
hyperparams.save_params_dict_to_yaml(params, params_save_path)
def save_gin_config(filename_suffix: str, model_dir: str):
"""Serializes and saves the experiment config."""
gin_save_path = os.path.join(
model_dir, 'operative_config.{}.gin'.format(filename_suffix))
logging.info('Saving gin configurations to %s', gin_save_path)
tf.io.gfile.makedirs(model_dir)
with tf.io.gfile.GFile(gin_save_path, 'w') as f:
f.write(gin.operative_config_str())
def read_global_step_from_checkpoint(ckpt_file_path):
"""Read global step from checkpoint, or get global step from its filename."""
global_step = tf.Variable(-1, dtype=tf.int64)
ckpt = tf.train.Checkpoint(global_step=global_step)
try:
ckpt.restore(ckpt_file_path).expect_partial()
global_step_maybe_restored = global_step.numpy()
except tf.errors.InvalidArgumentError:
global_step_maybe_restored = -1
if global_step_maybe_restored == -1:
raise ValueError('global_step not found in checkpoint {}. '
'If you want to run finetune eval jobs, you need to '
'make sure that your pretrain model writes '
'global_step in its checkpoints.'.format(ckpt_file_path))
global_step_restored = global_step.numpy()
logging.info('get global_step %d from checkpoint %s', global_step_restored,
ckpt_file_path)
return global_step_restored
def write_json_summary(log_dir, global_step, eval_metrics):
"""Dump evaluation metrics to json file."""
serializable_dict = {}
for name, value in eval_metrics.items():
if hasattr(value, 'numpy'):
serializable_dict[name] = str(value.numpy())
else:
serializable_dict[name] = str(value)
output_json = os.path.join(log_dir, 'metrics-{}.json'.format(global_step))
logging.info('Evaluation results at pretrain step %d: %s', global_step,
serializable_dict)
with tf.io.gfile.GFile(output_json, 'w') as writer:
writer.write(json.dumps(serializable_dict, indent=4) + '\n')
def write_summary(summary_writer, global_step, eval_metrics):
"""Write evaluation metrics to TF summary."""
numeric_dict = {}
for name, value in eval_metrics.items():
numeric_dict[name] = float(orbit.utils.get_value(value))
with summary_writer.as_default():
for name, value in numeric_dict.items():
tf.summary.scalar(name, value, step=global_step)
summary_writer.flush()
def remove_ckpts(model_dir):
"""Remove model checkpoints, so we can restart."""
ckpts = os.path.join(model_dir, 'ckpt-*')
logging.info('removing checkpoint files %s', ckpts)
for file_to_remove in tf.io.gfile.glob(ckpts):
tf.io.gfile.rmtree(file_to_remove)
file_to_remove = os.path.join(model_dir, 'checkpoint')
if tf.io.gfile.exists(file_to_remove):
tf.io.gfile.remove(file_to_remove)
def write_model_params(model: Union[tf.Module, tf_keras.Model],
output_path: str) -> None:
"""Writes the model parameters and shapes to a file.
Args:
model: A model instance.
output_path: Output file path.
"""
with tf.io.gfile.GFile(output_path, 'w') as f:
total_params = 0
for var in model.variables:
shape = tf.shape(var)
total_params += tf.math.reduce_prod(shape).numpy()
f.write(f'{var.name} {shape.numpy().tolist()}\n')
f.write(f'\nTotal params: {total_params}\n')
def try_count_params(
model: Union[tf.Module, tf_keras.Model],
trainable_only: bool = False):
"""Count the number of parameters if model is possible.
Args:
model: Try to count the number of params in this model.
trainable_only: Whether to calculate trainable params only. This flag is
not used when the model has `count_params` attribute.
Returns:
The number of parameters or None.
"""
if hasattr(model, 'count_params'):
try:
return model.count_params()
except ValueError:
logging.info('Number of trainable params unknown, because the build() '
'methods in keras layers were not called. This is probably '
'because the model was not feed any input, e.g., the max '
'train step already reached before this run.')
return None
else:
total_params = 0
variables = model.trainable_variables if trainable_only else model.variables
for var in variables:
shape = tf.shape(var)
total_params += tf.math.reduce_prod(shape).numpy()
return total_params
def try_count_flops(model: Union[tf.Module, tf_keras.Model],
inputs_kwargs: Optional[Dict[str, Any]] = None,
output_path: Optional[str] = None):
"""Counts and returns model FLOPs.
Args:
model: A model instance.
inputs_kwargs: An optional dictionary of argument pairs specifying inputs'
shape specifications to getting corresponding concrete function.
output_path: A file path to write the profiling results to.
Returns:
The model's FLOPs.
"""
if hasattr(model, 'inputs'):
try:
# Get input shape and set batch size to 1.
if model.inputs:
inputs = [
tf.TensorSpec([1] + input.shape[1:], input.dtype)
for input in model.inputs
]
concrete_func = tf.function(model).get_concrete_function(inputs)
# If model.inputs is invalid, try to use the input to get concrete
# function for model.call (subclass model).
else:
concrete_func = tf.function(model.call).get_concrete_function(
**inputs_kwargs)
frozen_func, _ = convert_variables_to_constants_v2_as_graph(concrete_func)
# Calculate FLOPs.
run_meta = tf.compat.v1.RunMetadata()
opts = tf.compat.v1.profiler.ProfileOptionBuilder.float_operation()
if output_path is not None:
opts['output'] = f'file:outfile={output_path}'
else:
opts['output'] = 'none'
flops = tf.compat.v1.profiler.profile(
graph=frozen_func.graph, run_meta=run_meta, options=opts)
return flops.total_float_ops
except Exception as e: # pylint: disable=broad-except
logging.info(
'Failed to count model FLOPs with error %s, because the build() '
'methods in keras layers were not called. This is probably because '
'the model was not feed any input, e.g., the max train step already '
'reached before this run.', e)
return None
return None
@ops.RegisterStatistics('Einsum', 'flops')
def _einsum_flops(graph, node):
"""Calculates the compute resources needed for Einsum."""
assert len(node.input) == 2
x_shape = tf.compat.v1.graph_util.tensor_shape_from_node_def_name(
graph, node.input[0])
y_shape = tf.compat.v1.graph_util.tensor_shape_from_node_def_name(
graph, node.input[1])
x_shape.assert_is_fully_defined()
y_shape.assert_is_fully_defined()
x_shape = x_shape.as_list()
y_shape = y_shape.as_list()
equation = str(node.attr['equation'])
equation = (
equation.replace('s:', '')
.replace('"', '')
.replace(' ', '')
.replace('\n', '')
)
x_str = equation.split(',')[0]
y_r_str = equation.split(',')[1]
y_str = y_r_str.split('->')[0]
r_str = y_r_str.split('->')[1]
shape_dic = {}
contracted = set()
for indice in x_str + y_str:
if indice in x_str:
indice_dim = x_shape[x_str.find(indice)]
elif indice in y_str:
indice_dim = y_shape[y_str.find(indice)]
else:
raise ValueError('indice {} not found in inputs'.format(indice))
shape_dic[indice] = indice_dim
if indice not in r_str:
contracted.add(indice)
madds = np.prod([shape_dic[indice] for indice in r_str]) * (
np.prod([shape_dic[indice] for indice in contracted]))
flops = 2 * madds
return ops.OpStats('flops', flops)