"""Helper functions for the Keras implementations of models.""" |
|
|
|
from __future__ import absolute_import |
|
from __future__ import division |
|
from __future__ import print_function |
|
|
|
import multiprocessing |
|
import os |
|
import time |
|
|
|
from absl import logging |
|
import tensorflow as tf |
|
|
|
|
|
class BatchTimestamp(object):
  """A structure to store batch time stamp."""

  def __init__(self, batch_index, timestamp):
    self.batch_index = batch_index
    self.timestamp = timestamp

  def __repr__(self):
    return "BatchTimestamp<batch_index: {}, timestamp: {}>".format(
        self.batch_index, self.timestamp)
|
|
class TimeHistory(tf.keras.callbacks.Callback):
  """Callback for Keras models."""

  def __init__(self, batch_size, log_steps, initial_step=0, logdir=None):
    """Callback for logging performance.

    Args:
      batch_size: Total batch size.
      log_steps: Interval of steps between logging of batch level stats.
      initial_step: Optional, initial step.
      logdir: Optional directory to write TensorBoard summaries.
    """
    super(TimeHistory, self).__init__()
    self.batch_size = batch_size
    self.log_steps = log_steps
    self.last_log_step = initial_step
    self.steps_before_epoch = initial_step
    self.steps_in_epoch = 0
    self.start_time = None

    if logdir:
      self.summary_writer = tf.summary.create_file_writer(logdir)
    else:
      self.summary_writer = None

    # Logs the start of step 1, then the end of each `log_steps` interval.
    self.timestamp_log = []

    # Records the time each epoch takes to run from start to finish.
    self.epoch_runtime_log = []
|
  @property
  def global_steps(self):
    """The current 1-indexed global step."""
    return self.steps_before_epoch + self.steps_in_epoch

  @property
  def average_steps_per_second(self):
    """The average training steps per second across all epochs."""
    return self.global_steps / sum(self.epoch_runtime_log)

  @property
  def average_examples_per_second(self):
    """The average number of training examples per second across all epochs."""
    return self.average_steps_per_second * self.batch_size
|
  def get_examples_per_sec(self, warmup=1):
    """Calculates examples/sec from timestamp_log, skipping warmup steps."""
    time_log = self.timestamp_log
    seconds = time_log[-1].timestamp - time_log[warmup].timestamp
    steps = time_log[-1].batch_index - time_log[warmup].batch_index
    return self.batch_size * steps / seconds

  def get_startup_time(self, start_time_sec):
    """Returns seconds from `start_time_sec` to the first logged batch."""
    return self.timestamp_log[0].timestamp - start_time_sec
|
  def on_train_end(self, logs=None):
    self.train_finish_time = time.time()

    if self.summary_writer:
      self.summary_writer.flush()

  def on_epoch_begin(self, epoch, logs=None):
    self.epoch_start = time.time()

  def on_batch_begin(self, batch, logs=None):
    if not self.start_time:
      self.start_time = time.time()

    # Record the timestamp of the first global step.
    if not self.timestamp_log:
      self.timestamp_log.append(BatchTimestamp(self.global_steps,
                                               self.start_time))
|
  def on_batch_end(self, batch, logs=None):
    """Records elapsed time of the batch and calculates examples per second."""
    self.steps_in_epoch = batch + 1
    steps_since_last_log = self.global_steps - self.last_log_step
    if steps_since_last_log >= self.log_steps:
      now = time.time()
      elapsed_time = now - self.start_time
      steps_per_second = steps_since_last_log / elapsed_time
      examples_per_second = steps_per_second * self.batch_size

      self.timestamp_log.append(BatchTimestamp(self.global_steps, now))
      logging.info(
          'TimeHistory: %.2f seconds, %.2f examples/second between steps %d '
          'and %d', elapsed_time, examples_per_second, self.last_log_step,
          self.global_steps)

      if self.summary_writer:
        with self.summary_writer.as_default():
          tf.summary.scalar('steps_per_second', steps_per_second,
                            step=self.global_steps)
          tf.summary.scalar('examples_per_second', examples_per_second,
                            step=self.global_steps)

      self.last_log_step = self.global_steps
      self.start_time = None
|
  def on_epoch_end(self, epoch, logs=None):
    epoch_run_time = time.time() - self.epoch_start
    self.epoch_runtime_log.append(epoch_run_time)

    self.steps_before_epoch += self.steps_in_epoch
    self.steps_in_epoch = 0
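
# A minimal usage sketch for TimeHistory (illustrative only; `model` and
# `train_ds` are hypothetical stand-ins, not part of this module):
def _example_time_history_usage(model, train_ds):
  time_callback = TimeHistory(batch_size=256, log_steps=100,
                              logdir='/tmp/tb_logs')
  model.fit(train_ds, epochs=2, callbacks=[time_callback])
  logging.info('Average examples/sec: %.2f',
               time_callback.average_examples_per_second)
  logging.info('Steady-state examples/sec: %.2f',
               time_callback.get_examples_per_sec(warmup=1))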
|
|
class SimpleCheckpoint(tf.keras.callbacks.Callback):
  """Keras callback to save tf.train.Checkpoints."""

  def __init__(self, checkpoint_manager):
    super(SimpleCheckpoint, self).__init__()
    self.checkpoint_manager = checkpoint_manager

  def on_epoch_end(self, epoch, logs=None):
    # Numbers the checkpoint by the manager's internal step counter.
    step_counter = self.checkpoint_manager._step_counter.numpy()  # pylint: disable=protected-access
    self.checkpoint_manager.save(checkpoint_number=step_counter)
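
# A minimal usage sketch for SimpleCheckpoint (illustrative only; `model`,
# `optimizer`, and `train_ds` are hypothetical stand-ins):
def _example_simple_checkpoint_usage(model, optimizer, train_ds):
  checkpoint = tf.train.Checkpoint(model=model, optimizer=optimizer)
  manager = tf.train.CheckpointManager(
      checkpoint, directory='/tmp/ckpts', max_to_keep=3,
      step_counter=optimizer.iterations)
  model.fit(train_ds, epochs=2, callbacks=[SimpleCheckpoint(manager)])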
|
|
def set_session_config(enable_xla=False):
  """Sets the session config."""
  if enable_xla:
    tf.config.optimizer.set_jit(True)
|
|
# Alias kept for backwards compatibility.
set_config_v2 = set_session_config
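
# A minimal usage sketch: enable XLA JIT compilation once at program startup,
# before the model is built (the flag value shown is illustrative):
#
#   set_session_config(enable_xla=True)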
|
|
def set_gpu_thread_mode_and_count(gpu_thread_mode,
                                  datasets_num_private_threads,
                                  num_gpus, per_gpu_thread_count):
  """Set GPU thread mode and count, and adjust dataset threads count.

  Returns:
    The (possibly derived) number of private threads for tf.data pipelines.
  """
  cpu_count = multiprocessing.cpu_count()
  logging.info('Logical CPU cores: %s', cpu_count)

  # Allocate a private thread pool for each GPU to schedule and launch kernels.
  per_gpu_thread_count = per_gpu_thread_count or 2
  os.environ['TF_GPU_THREAD_MODE'] = gpu_thread_mode
  os.environ['TF_GPU_THREAD_COUNT'] = str(per_gpu_thread_count)
  logging.info('TF_GPU_THREAD_COUNT: %s', os.environ['TF_GPU_THREAD_COUNT'])
  logging.info('TF_GPU_THREAD_MODE: %s', os.environ['TF_GPU_THREAD_MODE'])

  # Limit the data preprocessing threadpool to the CPU cores left over after
  # the GPU private threads and memory copy threads are accounted for.
  total_gpu_thread_count = per_gpu_thread_count * num_gpus
  num_runtime_threads = num_gpus
  if not datasets_num_private_threads:
    datasets_num_private_threads = min(
        cpu_count - total_gpu_thread_count - num_runtime_threads,
        num_gpus * 8)
    logging.info('Set datasets_num_private_threads to %s',
                 datasets_num_private_threads)
  # Return the thread count so callers can apply it to their tf.data pipelines.
  return datasets_num_private_threads
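
# A minimal usage sketch (illustrative values; typically called once at
# startup, before any tf.data pipelines are built):
def _example_gpu_thread_tuning():
  num_private_threads = set_gpu_thread_mode_and_count(
      gpu_thread_mode='gpu_private',
      datasets_num_private_threads=None,
      num_gpus=8,
      per_gpu_thread_count=2)
  logging.info('datasets_num_private_threads: %s', num_private_threads)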