"""Utility functions or classes shared between BERT benchmarks.""" |
|
|
|
from __future__ import absolute_import |
|
from __future__ import division |
|
from __future__ import print_function |
|
|
|
import time |
|
|
|
|
|
import numpy as np |
|
from absl import flags |
|
import tensorflow as tf |
|
|
|
|
|
from official.utils.flags import core as flags_core |
|
from official.benchmark.perfzero_benchmark import PerfZeroBenchmark |
|
|
|
FLAGS = flags.FLAGS |


class BenchmarkTimerCallback(tf.keras.callbacks.Callback):
  """Callback that records the time it takes to run each batch."""

  def __init__(self, num_batches_to_skip=10):
    super(BenchmarkTimerCallback, self).__init__()
    self.batch_start_times = {}
    self.batch_stop_times = {}

  def on_batch_begin(self, batch, logs=None):
    self.batch_start_times[batch] = time.time()

  def on_batch_end(self, batch, logs=None):
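    # With multiple steps per loop, the batch index reported at the end of a
    # loop may not match the index recorded in on_batch_begin; fall back to
    # the most recently recorded start index.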
    if batch not in self.batch_start_times:
      batch = max(self.batch_start_times.keys())

    self.batch_stop_times[batch] = time.time()

  def get_examples_per_sec(self, batch_size, num_batches_to_skip=1):
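    # Skip the first `num_batches_to_skip` batches, which typically carry
    # one-off warm-up overhead, then report throughput as batch_size divided
    # by the mean per-batch duration.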
    batch_durations = []
    for batch in self.batch_start_times:
      if batch in self.batch_stop_times and batch >= num_batches_to_skip:
        batch_durations.append(self.batch_stop_times[batch] -
                               self.batch_start_times[batch])
    return batch_size / np.mean(batch_durations)

  def get_startup_time(self, program_start_time):
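    # Startup time is measured from program start to the start of batch 0.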
return self.batch_start_times[0] - program_start_time |


class BertBenchmarkBase(PerfZeroBenchmark):
  """Base class to hold methods common to test classes."""
  local_flags = None

  def __init__(self, output_dir=None, tpu=None, **kwargs):
    super(BertBenchmarkBase, self).__init__(
        output_dir=output_dir, tpu=tpu, **kwargs)
    self.num_gpus = 8
    self.timer_callback = None

  def _setup(self):
    """Sets up and resets flags before each test."""
    super(BertBenchmarkBase, self)._setup()
    self.timer_callback = BenchmarkTimerCallback()
  def _report_benchmark(self, stats, wall_time_sec, min_accuracy,
                        max_accuracy):
    """Reports benchmark results by writing to a local protobuf file.

    Args:
      stats: dict returned from BERT models with known entries.
      wall_time_sec: the duration of the benchmark execution in seconds.
      min_accuracy: Minimum classification accuracy constraint to verify
        correctness of the model.
      max_accuracy: Maximum classification accuracy constraint to verify
        correctness of the model.
    """
    metrics = [{
        'name': 'training_loss',
        'value': stats['train_loss'],
    }]
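    # The timer callback records one timed interval per training loop of
    # FLAGS.steps_per_loop steps, so the number of examples per interval is
    # train_batch_size * steps_per_loop.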
    if self.timer_callback:
      metrics.append({
          'name': 'exp_per_second',
          'value': self.timer_callback.get_examples_per_sec(
              FLAGS.train_batch_size * FLAGS.steps_per_loop),
      })
    else:
      metrics.append({
          'name': 'exp_per_second',
          'value': 0.0,
      })
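    # Startup time can only be reported when the run recorded its program
    # start time in stats.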
    if self.timer_callback and 'start_time_sec' in stats:
      metrics.append({
          'name': 'startup_time',
          'value': self.timer_callback.get_startup_time(
              stats['start_time_sec']),
      })

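    # Accuracy is only present for runs that evaluate; the min/max bounds let
    # the benchmark harness verify the result falls in the expected range.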
    if 'eval_metrics' in stats:
      metrics.append({
          'name': 'eval_accuracy',
          'value': stats['eval_metrics'],
          'min_value': min_accuracy,
          'max_value': max_accuracy,
      })

    flags_str = flags_core.get_nondefault_flags_as_str()
    self.report_benchmark(
        iters=stats['total_training_steps'],
        wall_time=wall_time_sec,
        metrics=metrics,
        extras={'flags': flags_str})