|
|
|
|
|
|
|
|
|
|
|
""" |
|
FIXME |
|
This file is inherited from last version. |
|
|
|
I expect it can work with a few modifications to incorporate with the latest API, but it hasn't |
|
been tested and I'm not sure. |
|
""" |
|
|
|
from ..graph_v2 import IllegalGraphError, Cell, Edge, Graph, Node |
|
from ..operations_tf import Operation |
|
from ..type_utils import * |
|
|
|
|
|
def graph_to_tensorflow_script(graph: Graph) -> str:
    """Render every cell template of ``graph`` into one TensorFlow script.

    Each ``(name, cell)`` pair of ``graph.cell_templates`` is converted into
    the source of a model class by ``graph_to_tensorflow_model``.  The class
    sources are separated by one blank line, substituted into
    ``_TensorFlowScriptTemplate`` and the result is returned with leading and
    trailing whitespace stripped.
    """
    model_sources = []
    for template_name, cell in graph.cell_templates.items():
        model_sources.append(graph_to_tensorflow_model(template_name, cell))

    script_body = '\n\n'.join(model_sources)
    script = _TensorFlowScriptTemplate.format(script_body)
    return script.strip()
|
|
|
|
|
def _sort_incoming_edges(node: Node) -> List[Edge]:
    """Return the edges of ``node.graph`` whose ``tail`` is ``node``, in input order.

    Accepted layouts:

    * no incoming edge: an empty list is returned;
    * every ``tail_idx`` is ``None``: the edges are returned in the order in
      which they appear in ``node.graph.edges``;
    * every ``tail_idx`` is an ``int`` and, once sorted, the indices are
      exactly ``0, 1, ..., n - 1``: the edges are returned sorted by index.

    Raises:
        IllegalGraphError: for any other layout (mixed ``None``/``int``
            indices, gaps, duplicates, ...).
    """
    incoming = []
    for candidate in node.graph.edges:
        if candidate.tail is node:
            incoming.append(candidate)

    if not incoming:
        return []

    slots = [edge.tail_idx for edge in incoming]

    # Unindexed inputs: keep the graph's own edge order.
    if all(slot is None for slot in slots):
        return incoming

    # Indexed inputs: must form a dense 0..n-1 range after sorting.
    if all(isinstance(slot, int) for slot in slots):
        ordered = sorted(incoming, key=(lambda edge: edge.tail_idx))
        expected = list(range(len(ordered)))
        if [edge.tail_idx for edge in ordered] == expected:
            return ordered

    raise IllegalGraphError(node.graph, 'Node {} has bad inputs'.format(node.name))
|
|
|
def _format_inputs(node: Node) -> str:
    """Build the comma-separated argument list that feeds ``node``.

    The incoming edges are taken from ``_sort_incoming_edges``.  Each edge is
    rendered from its ``head`` (the producing node) and ``head_idx``:

    * head named ``'_inputs'``: the matching entry of
      ``node.graph.input_names`` if that list is set, otherwise
      ``_inputs[<head_idx>]``;
    * any other head: ``<head name>`` when ``head_idx`` is ``None``,
      otherwise ``<head name>[<head_idx>]``.

    Returns an empty string when the node has no incoming edge.
    """

    def render(edge):
        producer = edge.head
        if producer.name != '_inputs':
            if edge.head_idx is None:
                return '{}'.format(producer.name)
            return '{}[{}]'.format(producer.name, edge.head_idx)

        # Edges coming from the graph inputs must carry a positional index.
        assert isinstance(edge.head_idx, int)
        if node.graph.input_names is None:
            return '_inputs[{}]'.format(edge.head_idx)
        return node.graph.input_names[edge.head_idx]

    return ', '.join([render(edge) for edge in _sort_incoming_edges(node)])
|
|
|
|
|
def _placeholder_indent(template: str, field: str) -> str:
    """Return the whitespace that precedes ``{field}`` on its line in ``template``.

    An empty string is returned when no line of ``template`` starts (after its
    leading whitespace) with the placeholder.
    """
    marker = '{' + field + '}'
    for line in template.splitlines():
        content = line.lstrip()
        if content.startswith(marker):
            return line[:len(line) - len(content)]
    return ''


def graph_to_tensorflow_model(graph_name: str, graph: Graph) -> str:
    """Generate the source of one ``K.Model`` subclass for ``graph``.

    Args:
        graph_name: Name of the generated class.  The special name
            ``'_graph'`` is emitted as ``Graph``.
        graph: Graph to convert.  Its nodes are visited in the order returned
            by ``graph.topo_sort()``.

    Returns:
        ``_TensorFlowModelTemplate`` filled in with:

        * one ``self.<name> = ...`` statement per node (``Cell`` nodes are
          instantiated from their ``template_name`` without arguments, other
          nodes use ``operation.to_tensorflow_init()``);
        * one ``<name> = self.<name>(<inputs>)`` statement per node;
        * the ``call`` parameters: ``graph.input_names`` joined by commas, or
          ``*_inputs`` when ``input_names`` is ``None``;
        * the returned expression: the formatted inputs of
          ``graph.output_node``, or ``None`` when it has no input.
    """
    nodes = graph.topo_sort()

    # Statements for ``__init__``: one attribute per node.
    node_codes = []
    for node in nodes:
        if isinstance(node, Cell):
            node_codes.append('self.{} = {}()'.format(node.name, node.template_name))
        else:
            init_code = cast(Operation, node.operation).to_tensorflow_init()
            node_codes.append('self.{} = {}'.format(node.name, init_code))

    # Statements for ``call``: invoke every node on its formatted inputs.
    edge_codes = [
        '{} = self.{}({})'.format(node.name, node.name, _format_inputs(node))
        for node in nodes
    ]

    output_code = _format_inputs(graph.output_node) or 'None'

    if graph.input_names is None:
        input_code = '*_inputs'
    else:
        input_code = ', '.join(graph.input_names)

    # Continuation lines must sit at the same column as the placeholder they
    # replace, otherwise the generated method body does not parse.  Read that
    # column from the template instead of hard-coding it, so the separator and
    # the template cannot get out of sync.
    node_break = '\n' + _placeholder_indent(_TensorFlowModelTemplate, 'nodes')
    edge_break = '\n' + _placeholder_indent(_TensorFlowModelTemplate, 'edges')

    return _TensorFlowModelTemplate.format(
        graph_name=('Graph' if graph_name == '_graph' else graph_name),
        inputs=input_code,
        outputs=output_code,
        nodes=node_break.join(node_codes),
        edges=edge_break.join(edge_codes),
    )
|
|
|
|
|
_TensorFlowScriptTemplate = ''' |
|
import tensorflow as tf |
|
import tensorflow.keras as K |
|
|
|
import sdk.custom_ops_tf as CUSTOM |
|
|
|
{} |
|
''' |
|
|
|
_TensorFlowModelTemplate = ''' |
|
class {graph_name}(K.Model): |
|
def __init__(self): |
|
super().__init__() |
|
{nodes} |
|
|
|
def call(self, {inputs}): |
|
{edges} |
|
return {outputs} |
|
''' |