Spaces:
Running
Running
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""TensorBoard helper routine for TF op evaluator.

Requires TensorFlow.
"""
import threading | |
class PersistentOpEvaluator:
    """Evaluate a fixed TensorFlow graph repeatedly, safely, efficiently.

    Extend this class to create a particular kind of op evaluator, like an
    image encoder. In `initialize_graph`, create an appropriate TensorFlow
    graph with placeholder inputs. In `run`, evaluate this graph and
    return its result. Each evaluator instance lazily builds exactly one
    graph and one session (on first call) and reuses them afterwards, to
    preserve memory usage. Because the graph is private to the instance,
    it does not interfere with other concurrent sessions.

    A subclass of this class offers a threadsafe, highly parallel Python
    entry point for evaluating a particular TensorFlow graph.

    Example usage:

        class FluxCapacitanceEvaluator(PersistentOpEvaluator):
            \"\"\"Compute the flux capacitance required for a system.

            Arguments:
              x: Available power input, as a `float`, in jigawatts.

            Returns:
              A `float`, in nanofarads.
            \"\"\"

            def initialize_graph(self):
                self._placeholder = tf.placeholder(some_dtype)
                self._op = some_op(self._placeholder)

            def run(self, x):
                return self._op.eval(feed_dict={self._placeholder: x})

        evaluate_flux_capacitance = FluxCapacitanceEvaluator()
        for x in xs:
            evaluate_flux_capacitance(x)
    """

    def __init__(self):
        super().__init__()
        # The session is created lazily by `_lazily_initialize`; `None`
        # means "not initialized yet".
        self._session = None
        # Serializes the one-time construction of the graph and session.
        self._initialization_lock = threading.Lock()

    def _lazily_initialize(self):
        """Initialize the graph and session, if this has not yet been done.

        Safe to call concurrently from several threads: only one of them
        builds the graph and session. Once initialization has finished,
        this is a single attribute check and does not touch the lock.

        If `initialize_graph` raises, `_session` stays `None`, so the next
        call retries initialization from scratch.
        """
        # Fast path: `_session` is only ever assigned a fully-constructed
        # session (see below), so a non-`None` value means we are done.
        if self._session is not None:
            return

        # TODO(nickfelt): remove on-demand imports once dep situation is fixed.
        import tensorflow.compat.v1 as tf

        with self._initialization_lock:
            # Re-check under the lock: another thread may have completed
            # initialization while this one was waiting to acquire it.
            if self._session is not None:
                return
            graph = tf.Graph()
            with graph.as_default():
                self.initialize_graph()
            # Don't reserve GPU because libpng can't run on GPU.
            config = tf.ConfigProto(device_count={"GPU": 0})
            # Publish the session only once it is fully built, so the
            # unlocked fast path above never observes a partial state.
            self._session = tf.Session(graph=graph, config=config)

    def initialize_graph(self):
        """Create the TensorFlow graph needed to compute this operation.

        This should write ops to the default graph and return `None`.
        It is invoked once per evaluator instance (or again on a later call
        if a previous attempt raised), with this instance's private graph
        installed as the default graph.

        Raises:
          NotImplementedError: always, in this base class.
        """
        raise NotImplementedError(
            'Subclasses must implement "initialize_graph".'
        )

    def run(self, *args, **kwargs):
        """Evaluate the ops with the given input.

        When this function is called, the default session will have the
        graph defined by a previous call to `initialize_graph`. This
        function should evaluate any ops necessary to compute the result
        of the query for the given *args and **kwargs, likely returning
        the result of a call to `some_op.eval(...)`.

        Raises:
          NotImplementedError: always, in this base class.
        """
        raise NotImplementedError('Subclasses must implement "run".')

    def __call__(self, *args, **kwargs):
        """Evaluate this op on the given arguments.

        Builds the graph and session on first use, then calls `run` with
        `*args` and `**kwargs` while this evaluator's session is installed
        as the default session.

        Returns:
          Whatever the subclass's `run` returns.
        """
        self._lazily_initialize()
        with self._session.as_default():
            return self.run(*args, **kwargs)