# Kano001's picture
# Upload 527 files
# cf2a15a verified
# raw
# history blame
# 3.79 kB
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""TensorBoard helper routine for TF op evaluator.
Requires TensorFlow.
"""
import threading
class PersistentOpEvaluator:
    """Evaluate a fixed TensorFlow graph repeatedly, safely, efficiently.

    Extend this class to create a particular kind of op evaluator, like an
    image encoder. In `initialize_graph`, create an appropriate TensorFlow
    graph with placeholder inputs. In `run`, evaluate this graph and
    return its result. This class will manage a singleton graph and
    session to limit memory usage, and will ensure that this graph and
    session do not interfere with other concurrent sessions.

    A subclass of this class offers a threadsafe, highly parallel Python
    entry point for evaluating a particular TensorFlow graph.

    The graph and session are built lazily, on the first call to the
    evaluator instance, and then reused for every later call. This class
    never closes the session it creates.

    Example usage:

        class FluxCapacitanceEvaluator(PersistentOpEvaluator):
            \"\"\"Compute the flux capacitance required for a system.

            Arguments:
              x: Available power input, as a `float`, in jigawatts.

            Returns:
              A `float`, in nanofarads.
            \"\"\"

            def initialize_graph(self):
                self._placeholder = tf.placeholder(some_dtype)
                self._op = some_op(self._placeholder)

            def run(self, x):
                return self._op.eval(feed_dict={self._placeholder: x})

        evaluate_flux_capacitance = FluxCapacitanceEvaluator()
        for x in xs:
            evaluate_flux_capacitance(x)
    """

    def __init__(self):
        super().__init__()
        # Created on first use by `_lazily_initialize`; `None` until then.
        self._session = None
        # Serializes the one-time graph/session construction across threads.
        self._initialization_lock = threading.Lock()

    def _lazily_initialize(self):
        """Initialize the graph and session, if this has not yet been done.

        The check and the construction both happen while holding
        `self._initialization_lock`. The session is stored only after
        `initialize_graph` and the session constructor have returned, so
        if either raises, nothing is stored and the next call tries again.
        """
        with self._initialization_lock:
            if self._session is not None:
                return
            # TODO(nickfelt): remove on-demand imports once dep situation is
            # fixed.
            # Imported only on the initializing path, so calls made after the
            # session exists do not go through the import machinery at all.
            import tensorflow.compat.v1 as tf

            graph = tf.Graph()
            with graph.as_default():
                self.initialize_graph()
            # Don't reserve GPU because libpng can't run on GPU.
            config = tf.ConfigProto(device_count={"GPU": 0})
            self._session = tf.Session(graph=graph, config=config)

    def initialize_graph(self):
        """Create the TensorFlow graph needed to compute this operation.

        This should write ops to the default graph and return `None`.

        Raises:
          NotImplementedError: Always, in this base class; subclasses must
            override this method.
        """
        raise NotImplementedError(
            'Subclasses must implement "initialize_graph".'
        )

    def run(self, *args, **kwargs):
        """Evaluate the ops with the given input.

        When this function is called, the default session will have the
        graph defined by a previous call to `initialize_graph`. This
        function should evaluate any ops necessary to compute the result
        of the query for the given *args and **kwargs, likely returning
        the result of a call to `some_op.eval(...)`.

        Raises:
          NotImplementedError: Always, in this base class; subclasses must
            override this method.
        """
        raise NotImplementedError('Subclasses must implement "run".')

    def __call__(self, *args, **kwargs):
        """Run the evaluator, building the graph and session on first use.

        All arguments are forwarded unchanged to `run`, which executes with
        this evaluator's session installed as the default session.

        Returns:
          Whatever `run` returns.
        """
        self._lazily_initialize()
        with self._session.as_default():
            return self.run(*args, **kwargs)