File size: 3,354 Bytes
b84549f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

# pylint: skip-file

"""
FIXME
This file is inherited from last version.

I expect it can work with a few modifications to incorporate with the latest API, but it hasn't
been tested and I'm not sure.
"""

from ..graph_v2 import IllegalGraphError, Cell, Edge, Graph, Node
from ..operations_tf import Operation
from ..type_utils import *


def graph_to_tensorflow_script(graph: Graph) -> str:
    graphs = [graph_to_tensorflow_model(name, cell) for name, cell in graph.cell_templates.items()]
    return _TensorFlowScriptTemplate.format('\n\n'.join(graphs)).strip()


def _sort_incoming_edges(node: Node) -> List[Edge]:
    edges = [edge for edge in node.graph.edges if edge.tail is node]
    if not edges:
        return []
    if all(edge.tail_idx is None for edge in edges):
        return edges
    if all(isinstance(edge.tail_idx, int) for edge in edges):
        edges = sorted(edges, key=(lambda edge: edge.tail_idx))
        if [edge.tail_idx for edge in edges] == list(range(len(edges))):
            return edges
    raise IllegalGraphError(node.graph, 'Node {} has bad inputs'.format(node.name))

def _format_inputs(node: Node) -> str:
    edges = _sort_incoming_edges(node)
    inputs = []
    for edge in edges:
        if edge.head.name == '_inputs':
            assert isinstance(edge.head_idx, int)
            if node.graph.input_names is not None:
                inputs.append(node.graph.input_names[edge.head_idx])
            else:
                inputs.append('_inputs[{}]'.format(edge.head_idx))
        else:
            if edge.head_idx is None:
                inputs.append('{}'.format(edge.head.name))
            else:
                inputs.append('{}[{}]'.format(edge.head.name, edge.head_idx))
    return ', '.join(inputs)


def graph_to_tensorflow_model(graph_name: str, graph: Graph) -> str:
    nodes = graph.topo_sort()

    # handle module node and function node differently
    # only need to generate code for module here
    node_codes = []
    for node in nodes:
        if isinstance(node, Cell):
            node_codes.append('self.{} = {}()'.format(node.name, node.template_name))
        else:
            node_codes.append('self.{} = {}'.format(node.name, cast(Operation, node.operation).to_tensorflow_init()))

    edge_codes = []

    for node in nodes:
        inputs = _format_inputs(node)
        edge_codes.append('{} = self.{}({})'.format(node.name, node.name, inputs))

    output_code = _format_inputs(graph.output_node)
    if not output_code:
        output_code = 'None'

    if graph.input_names is None:
        input_code = '*_inputs'
    else:
        input_code = ', '.join(graph.input_names)

    linebreak = '\n        '
    return _TensorFlowModelTemplate.format(
        graph_name=('Graph' if graph_name == '_graph' else graph_name),
        inputs=input_code,
        outputs=output_code,
        nodes=linebreak.join(node_codes),
        edges=linebreak.join(edge_codes)
    )


_TensorFlowScriptTemplate = '''
import tensorflow as tf
import tensorflow.keras as K

import sdk.custom_ops_tf as CUSTOM

{}
'''

_TensorFlowModelTemplate = '''
class {graph_name}(K.Model):
    def __init__(self):
        super().__init__()
        {nodes}

    def call(self, {inputs}):
        {edges}
        return {outputs}
'''