Spaces:
Running
Running
# Copyright (c) Microsoft Corporation. | |
# Licensed under the MIT license. | |
# pylint: skip-file | |
""" | |
FIXME | |
This file is inherited from last version. | |
I expect it can work with a few modifications to incorporate with the latest API, but it hasn't | |
been tested and I'm not sure. | |
""" | |
from ..graph_v2 import IllegalGraphError, Cell, Edge, Graph, Node | |
from ..operations_tf import Operation | |
from ..type_utils import * | |
def graph_to_tensorflow_script(graph: Graph) -> str: | |
graphs = [graph_to_tensorflow_model(name, cell) for name, cell in graph.cell_templates.items()] | |
return _TensorFlowScriptTemplate.format('\n\n'.join(graphs)).strip() | |
def _sort_incoming_edges(node: Node) -> List[Edge]: | |
edges = [edge for edge in node.graph.edges if edge.tail is node] | |
if not edges: | |
return [] | |
if all(edge.tail_idx is None for edge in edges): | |
return edges | |
if all(isinstance(edge.tail_idx, int) for edge in edges): | |
edges = sorted(edges, key=(lambda edge: edge.tail_idx)) | |
if [edge.tail_idx for edge in edges] == list(range(len(edges))): | |
return edges | |
raise IllegalGraphError(node.graph, 'Node {} has bad inputs'.format(node.name)) | |
def _format_inputs(node: Node) -> str: | |
edges = _sort_incoming_edges(node) | |
inputs = [] | |
for edge in edges: | |
if edge.head.name == '_inputs': | |
assert isinstance(edge.head_idx, int) | |
if node.graph.input_names is not None: | |
inputs.append(node.graph.input_names[edge.head_idx]) | |
else: | |
inputs.append('_inputs[{}]'.format(edge.head_idx)) | |
else: | |
if edge.head_idx is None: | |
inputs.append('{}'.format(edge.head.name)) | |
else: | |
inputs.append('{}[{}]'.format(edge.head.name, edge.head_idx)) | |
return ', '.join(inputs) | |
def graph_to_tensorflow_model(graph_name: str, graph: Graph) -> str: | |
nodes = graph.topo_sort() | |
# handle module node and function node differently | |
# only need to generate code for module here | |
node_codes = [] | |
for node in nodes: | |
if isinstance(node, Cell): | |
node_codes.append('self.{} = {}()'.format(node.name, node.template_name)) | |
else: | |
node_codes.append('self.{} = {}'.format(node.name, cast(Operation, node.operation).to_tensorflow_init())) | |
edge_codes = [] | |
for node in nodes: | |
inputs = _format_inputs(node) | |
edge_codes.append('{} = self.{}({})'.format(node.name, node.name, inputs)) | |
output_code = _format_inputs(graph.output_node) | |
if not output_code: | |
output_code = 'None' | |
if graph.input_names is None: | |
input_code = '*_inputs' | |
else: | |
input_code = ', '.join(graph.input_names) | |
linebreak = '\n ' | |
return _TensorFlowModelTemplate.format( | |
graph_name=('Graph' if graph_name == '_graph' else graph_name), | |
inputs=input_code, | |
outputs=output_code, | |
nodes=linebreak.join(node_codes), | |
edges=linebreak.join(edge_codes) | |
) | |
_TensorFlowScriptTemplate = ''' | |
import tensorflow as tf | |
import tensorflow.keras as K | |
import sdk.custom_ops_tf as CUSTOM | |
{} | |
''' | |
_TensorFlowModelTemplate = ''' | |
class {graph_name}(K.Model): | |
def __init__(self): | |
super().__init__() | |
{nodes} | |
def call(self, {inputs}): | |
{edges} | |
return {outputs} | |
''' |