|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Runs a memory usage benchmark for a TensorFlow Hub model.
|
|
|
Loads a SavedModel and records memory usage. |
|
""" |
|
import functools
import re
import time

from absl import flags
import tensorflow as tf
import tensorflow_hub as hub

from official.benchmark.perfzero_benchmark import PerfZeroBenchmark
|
|
|
# Module-level handle to the global absl flag registry.
# NOTE(review): not referenced by the code in this file; presumably kept for
# the PerfZero flag handling done by the base class -- confirm before removing.
FLAGS = flags.FLAGS
|
|
|
|
|
class TfHubMemoryUsageBenchmark(PerfZeroBenchmark):
  """A benchmark measuring memory usage for a given TF Hub SavedModel.

  `benchmark_memory_usage` loads a single model handle. In addition, the
  constructor can attach one `benchmark_<name>` method per handle listed in
  `hub_model_handle_list`, each bound to its handle, so several models can be
  run as separately named benchmarks.
  """

  # Prefix removed from a handle before it is converted into a method name.
  _TFHUB_URL_PREFIX = 'https://tfhub.dev'

  def __init__(self,
               hub_model_handle_list=None,
               output_dir=None,
               default_flags=None,
               root_data_dir=None,
               **kwargs):
    """Initializes the benchmark.

    Args:
      hub_model_handle_list: Optional string of ';'-separated TF Hub model
        handles. Whitespace around each handle is stripped and empty entries
        (e.g. from a trailing ';') are skipped. For every remaining handle a
        `benchmark_<name>` attribute is set on this instance that runs
        `benchmark_memory_usage` with that handle.
      output_dir: Forwarded to `PerfZeroBenchmark.__init__`.
      default_flags: Forwarded to `PerfZeroBenchmark.__init__`.
      root_data_dir: Accepted but ignored; it is not forwarded to the base
        class. NOTE(review): presumably accepted so that a harness passing it
        does not fail -- confirm.
      **kwargs: Forwarded to `PerfZeroBenchmark.__init__`.
    """
    super().__init__(
        output_dir=output_dir, default_flags=default_flags, **kwargs)
    del root_data_dir  # Unused by this benchmark.

    if not hub_model_handle_list:
      return
    for raw_handle in hub_model_handle_list.split(';'):
      hub_model_handle = raw_handle.strip()
      if not hub_model_handle:
        # Tolerate stray separators instead of registering a bogus benchmark.
        continue
      method_name = 'benchmark_' + self._handle_to_method_suffix(
          hub_model_handle)
      setattr(self, method_name,
              functools.partial(self.benchmark_memory_usage, hub_model_handle))

  @classmethod
  def _handle_to_method_suffix(cls, hub_model_handle):
    """Converts a model handle into a string usable in a method name.

    For example 'https://tfhub.dev/google/nnlm-en-dim128/1' becomes
    'google_nnlm_en_dim128_1'. Every character other than an ASCII letter,
    digit or underscore is mapped to '_', so handles containing e.g. '.' or
    ':' still yield valid identifiers.

    Args:
      hub_model_handle: A stripped, non-empty model handle string.

    Returns:
      The sanitized suffix with leading/trailing underscores removed.
    """
    suffix = hub_model_handle
    if suffix.startswith(cls._TFHUB_URL_PREFIX):
      suffix = suffix[len(cls._TFHUB_URL_PREFIX):]
    return re.sub(r'[^0-9A-Za-z_]', '_', suffix).strip('_')

  def benchmark_memory_usage(
      self, hub_model_handle='https://tfhub.dev/google/nnlm-en-dim128/1'):
    """Loads a TF Hub model and reports the load time via `report_benchmark`.

    Args:
      hub_model_handle: Model handle passed to `hub.load`.
    """
    # perf_counter is monotonic, so the measured interval is not distorted by
    # system clock adjustments while the model loads.
    start_time_sec = time.perf_counter()
    self.load_model(hub_model_handle)
    wall_time_sec = time.perf_counter() - start_time_sec

    # No metrics are computed in-process; only the wall time is reported.
    # NOTE(review): memory usage is presumably captured by the PerfZero
    # harness outside this method -- confirm.
    metrics = []
    self.report_benchmark(iters=-1, wall_time=wall_time_sec, metrics=metrics)

  def load_model(self, hub_model_handle):
    """Loads a TF Hub module.

    The object returned by `hub.load` is discarded; only the act of loading is
    exercised.

    Args:
      hub_model_handle: Model handle passed to `hub.load`.
    """
    hub.load(hub_model_handle)
|
|
|
|
|
if __name__ == '__main__':

  # Script entry point: delegates to TensorFlow's test runner (`tf.test.main`).
  tf.test.main()
|
|