File size: 3,792 Bytes
cf2a15a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""TensorBoard helper routine for TF op evaluator.

Requires TensorFlow.
"""


import threading


class PersistentOpEvaluator:
    """Evaluate a fixed TensorFlow graph repeatedly, safely, efficiently.

    Extend this class to create a particular kind of op evaluator, like an
    image encoder. In `initialize_graph`, create an appropriate TensorFlow
    graph with placeholder inputs. In `run`, evaluate this graph and
    return its result. Each evaluator instance manages a single graph and
    session, built lazily on first use and reused afterwards to conserve
    memory. The graph is private to the instance, so it does not interfere
    with other concurrent sessions.

    A subclass of this class offers a threadsafe, highly parallel Python
    entry point for evaluating a particular TensorFlow graph.

    Example usage:

        class FluxCapacitanceEvaluator(PersistentOpEvaluator):
          \"\"\"Compute the flux capacitance required for a system.

          Arguments:
            x: Available power input, as a `float`, in jigawatts.

          Returns:
            A `float`, in nanofarads.
          \"\"\"

          def initialize_graph(self):
            self._placeholder = tf.placeholder(some_dtype)
            self._op = some_op(self._placeholder)

          def run(self, x):
            return self._op.eval(feed_dict={self._placeholder: x})

        evaluate_flux_capacitance = FluxCapacitanceEvaluator()

        for x in xs:
          evaluate_flux_capacitance(x)
    """

    def __init__(self):
        super().__init__()
        # The session (and the graph it owns) is created lazily on the first
        # call; `None` means "not yet initialized".
        self._session = None
        # Guards the one-time graph/session construction across threads.
        self._initialization_lock = threading.Lock()

    def _lazily_initialize(self):
        """Initialize the graph and session, if this has not yet been done.

        Uses double-checked locking. Once `self._session` is set, this
        returns immediately without importing TensorFlow or acquiring the
        lock, so concurrent callers are not serialized on the hot path.
        `self._session` is assigned only after the graph has been fully
        built, so a non-`None` value always refers to a ready session.

        If `initialize_graph` raises, no session is stored, so the next
        call retries initialization.
        """
        if self._session is not None:
            return
        # TODO(nickfelt): remove on-demand imports once dep situation is fixed.
        import tensorflow.compat.v1 as tf

        with self._initialization_lock:
            # Re-check: another thread may have finished initializing while
            # this one was waiting for the lock.
            if self._session is not None:
                return
            graph = tf.Graph()
            with graph.as_default():
                self.initialize_graph()
            # Don't reserve GPU because libpng can't run on GPU.
            config = tf.ConfigProto(device_count={"GPU": 0})
            self._session = tf.Session(graph=graph, config=config)

    def initialize_graph(self):
        """Create the TensorFlow graph needed to compute this operation.

        This should write ops to the default graph and return `None`.
        It is called at most once per successful initialization, with this
        evaluator's private graph installed as the default graph.

        Raises:
          NotImplementedError: always, unless overridden by a subclass.
        """
        raise NotImplementedError(
            'Subclasses must implement "initialize_graph".'
        )

    def run(self, *args, **kwargs):
        """Evaluate the ops with the given input.

        When this function is called, the default session will have the
        graph defined by a previous call to `initialize_graph`. This
        function should evaluate any ops necessary to compute the result
        of the query for the given *args and **kwargs, likely returning
        the result of a call to `some_op.eval(...)`.

        Raises:
          NotImplementedError: always, unless overridden by a subclass.
        """
        raise NotImplementedError('Subclasses must implement "run".')

    def __call__(self, *args, **kwargs):
        """Evaluate this evaluator's graph on the given inputs.

        Builds the graph and session on first use, then delegates to `run`
        with this evaluator's session installed as the default session.

        Returns:
          Whatever the subclass's `run` returns for `*args, **kwargs`.
        """
        self._lazily_initialize()
        with self._session.as_default():
            return self.run(*args, **kwargs)