|
|
|
|
|
|
|
import logging |
|
from typing import List, Tuple, Any |
|
|
|
from ..graph import IllegalGraphError, Edge, Graph, Node, Model |
|
|
|
_logger = logging.getLogger(__name__) |
|
|
|
|
|
def model_to_pytorch_script(model: Model, placement=None) -> str:
    """
    Generate one PyTorch script covering every graph of ``model``.

    Parameters
    ----------
    model : Model
        the model whose ``graphs`` mapping (graph name -> graph) is converted
    placement
        forwarded unchanged to ``graph_to_pytorch_model`` for every graph

    Returns
    -------
    str
        ``_PyTorchScriptTemplate`` filled with the collected ``import`` lines
        and the generated class definitions (separated by a blank line),
        with leading/trailing whitespace stripped
    """
    graphs = []
    total_pkgs = set()
    for name, cell in model.graphs.items():
        import_pkgs, graph_code = graph_to_pytorch_model(name, cell, placement=placement)
        graphs.append(graph_code)
        total_pkgs.update(import_pkgs)
    # A set of strings iterates in hash order, which can differ from one
    # interpreter run to the next; sorting keeps the emitted script stable
    # for the same model.
    pkgs_code = '\n'.join('import {}'.format(pkg) for pkg in sorted(total_pkgs))
    return _PyTorchScriptTemplate.format(pkgs_code, '\n\n'.join(graphs)).strip()
|
|
|
|
|
def _sorted_incoming_edges(node: Node) -> List[Edge]:
    """
    Collect the edges of ``node.graph`` whose tail is ``node``, ordered by slot.

    Parameters
    ----------
    node : Node
        the node whose incoming edges are wanted

    Returns
    -------
    list
        empty if nothing points at the node; the edges in graph order if
        every ``tail_slot`` is None; otherwise the edges sorted by
        ``tail_slot`` when the slots are exactly 0, 1, ..., n-1

    Raises
    ------
    IllegalGraphError
        if the slots mix None with integers, are not all integers, or do not
        form the contiguous range starting at 0
    """
    incoming = list(filter(lambda e: e.tail is node, node.graph.edges))
    _logger.debug('sorted_incoming_edges: %s', str(incoming))
    if len(incoming) == 0:
        return []

    slots = [e.tail_slot for e in incoming]
    _logger.debug('all tail_slots are None: %s', str(slots))

    # No slot information at all: keep the order the graph stores them in.
    if not any(slot is not None for slot in slots):
        return incoming

    if all(isinstance(slot, int) for slot in slots):
        incoming.sort(key=lambda e: e.tail_slot)
        expected_slots = list(range(len(incoming)))
        if [e.tail_slot for e in incoming] == expected_slots:
            return incoming

    raise IllegalGraphError(node.graph, 'Node {} has bad inputs'.format(node.name))
|
|
|
|
|
def _format_inputs(node: Node) -> Tuple[List[str], List[Any]]:
    """
    Build the textual input expressions for a node, one per incoming edge.

    Parameters
    ----------
    node : Node
        the graph node whose incoming edges are formatted

    Returns
    -------
    list
        the input expressions, in the order given by ``_sorted_incoming_edges``
    list
        parallel list of values: the ``'value'`` parameter of the producing
        operation when that operation is ``prim::Constant`` or
        ``prim::GetAttr``, has such a parameter and is referenced without a
        slot; None in every other case
    """
    names = []
    values = []
    for edge in _sorted_incoming_edges(node):
        source = edge.head
        slot = edge.head_slot
        value = None
        if source.name == '_inputs':
            assert isinstance(slot, int)
            io_names = source.operation.io_names
            if io_names is None:
                # no declared argument names: index into the variadic ``_inputs``
                name = '_inputs[{}]'.format(slot)
            else:
                # refer to the graph input by its declared name
                name = io_names[slot]
        elif slot is not None:
            # one particular output of a multi-output producer
            name = '{}[{}]'.format(source.name, slot)
        else:
            # the producer's whole output
            name = '{}'.format(source.name)
            operation = source.operation
            if operation.type in ('prim::Constant', 'prim::GetAttr') and \
                    'value' in operation.parameters:
                value = operation.parameters['value']
        names.append(name)
        values.append(value)
    return names, values
|
|
|
|
|
def _remove_prefix(names, graph_name): |
|
""" |
|
variables name (full name space) is too long, |
|
shorten the name by removing the prefix ```graph_name``` |
|
""" |
|
if isinstance(names, list): |
|
converted_names = [] |
|
for name in names: |
|
if name.startswith(graph_name): |
|
converted_names.append(name[len(graph_name):]) |
|
else: |
|
converted_names.append(name) |
|
return converted_names |
|
else: |
|
return names[len(graph_name):] if names.startswith(graph_name) else names |
|
|
|
|
|
def graph_to_pytorch_model(graph_name: str, graph: Graph, placement=None) -> Tuple[set, str]:
    """
    Generate the source code of one ``nn.Module`` subclass from a graph.

    Parameters
    ----------
    graph_name : str
        name of the graph; it becomes the class name (``'_graph'`` is emitted
        as ``'Graph'``) and is stripped as a prefix from node and input names
    graph : Graph
        the graph to convert
    placement
        optional mapping from node to an object with a ``device`` attribute;
        for a node found in it, ``.to('<device>')`` is appended to the node's
        non-empty init code

    Returns
    -------
    set
        the packages reported by the operations' ``get_import_pkg()``
    str
        ``_PyTorchModelTemplate`` filled in for this graph

    Raises
    ------
    RuntimeError
        if no output name can be derived for ``graph.output_node``
    """
    # A single topological order is used for both the __init__ and forward passes.
    nodes = graph.topo_sort()

    # ---- body of __init__: one statement per (non-shared) operation ----
    import_pkgs = set()
    node_codes = []
    for node in nodes:
        if not node.operation:
            continue
        if node.operation.type == 'shared':
            # a shared node calls the submodule named by its 'reference'
            # parameter in forward, so it gets no init code of its own
            continue
        pkg_name = node.operation.get_import_pkg()
        if pkg_name is not None:
            import_pkgs.add(pkg_name)
        node_code = node.operation.to_init_code(_remove_prefix(node.name, graph_name))
        if node_code is not None:
            if placement and node in placement and len(node_code) > 0:
                node_codes.append(f"{node_code}.to('{placement[node].device}')")
            else:
                node_codes.append(node_code)

    # ---- signature of forward ----
    if graph.input_node.operation.io_names is None:
        input_code = '*_inputs'
    else:
        for name in graph.input_node.operation.io_names:
            # declared input names are used verbatim, so they must not carry the prefix
            assert not name.startswith(graph_name)
        input_code = ', '.join(graph.input_node.operation.io_names)

    # ---- body of forward: one statement per operation ----
    edge_codes = []
    for node in nodes:
        if node.operation:
            inputs, inputs_value = _format_inputs(node)
            inputs = _remove_prefix(inputs, graph_name)
            node_name = _remove_prefix(node.name, graph_name)
            submodule_name = node_name
            if node.operation.type == 'shared':
                submodule_name = _remove_prefix(node.operation.parameters['reference'], graph_name)
            edge_codes.append(node.operation.to_forward_code(submodule_name, node_name, inputs, inputs_value))

    # ---- return statement of forward ----
    output_names, _ = _format_inputs(graph.output_node)
    output_names = _remove_prefix(output_names, graph_name)
    if not output_names:
        raise RuntimeError('"forward" function should have return value(s): {}, {}, {}'.format(output_names, graph_name, graph.output_node))
    output_code = ', '.join(output_names)

    # Generated statements sit inside a method of the generated class, so every
    # statement after the first needs the 8-space method-body indentation.
    linebreak = '\n' + ' ' * 8
    return import_pkgs, _PyTorchModelTemplate.format(
        graph_name=('Graph' if graph_name == '_graph' else graph_name),
        inputs=input_code,
        outputs=output_code,
        nodes=linebreak.join(node_codes),
        edges=linebreak.join(edge_codes)
    )
|
|
|
|
|
|
|
|
|
# Skeleton of a generated script. It has two positional slots, filled by
# ``model_to_pytorch_script``: first the extra ``import ...`` lines, then the
# generated class definitions.
_PyTorchScriptTemplate = '''
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

import nni.retiarii.nn.pytorch

{}

{}
'''
|
|
|
_PyTorchModelTemplate = ''' |
|
class {graph_name}(nn.Module): |
|
def __init__(self): |
|
super().__init__() |
|
{nodes} |
|
|
|
def forward(self, {inputs}): |
|
{edges} |
|
return {outputs} |
|
''' |
|
|