Spaces:
Running
Running
# Copyright (c) Microsoft Corporation. | |
# Licensed under the MIT license. | |
from typing import (Any, Dict, List) | |
from . import debug_configs | |
__all__ = ['Operation', 'Cell'] | |
def _convert_name(name: str) -> str: | |
""" | |
Convert the names using separator '.' to valid variable name in code | |
""" | |
return name.replace('.', '__') | |
class Operation: | |
""" | |
Calculation logic of a graph node. | |
The constructor is private. Use `Operation.new()` to create operation object. | |
`Operation` is a naive record. | |
Do not "mutate" its attributes or store information relate to specific node. | |
All complex logic should be implemented in `Node` class. | |
Attributes | |
---------- | |
type | |
Operation type name (e.g. Conv2D). | |
If it starts with underscore, the "operation" is a special one (e.g. subgraph, input/output). | |
parameters | |
Arbitrary key-value parameters (e.g. kernel_size). | |
""" | |
def __init__(self, type_name: str, parameters: Dict[str, Any], _internal: bool = False): | |
assert _internal, '`Operation()` is private, use `Operation.new()` instead' | |
self.type: str = type_name | |
self.parameters: Dict[str, Any] = parameters | |
def to_init_code(self, field: str) -> str: | |
raise NotImplementedError() | |
def to_forward_code(self, field: str, output: str, inputs: List[str]) -> str: | |
raise NotImplementedError() | |
def _to_class_name(self) -> str: | |
raise NotImplementedError() | |
def __bool__(self) -> bool: | |
return True | |
def new(type_name: str, parameters: Dict[str, Any] = {}, cell_name: str = None) -> 'Operation': | |
if type_name == '_cell': | |
# NOTE: cell_name is the same as its Node's name, when the cell is wrapped within the node | |
return Cell(cell_name, parameters) | |
else: | |
if debug_configs.framework.lower() in ('torch', 'pytorch'): | |
from .operation_def import torch_op_def # pylint: disable=unused-import | |
cls = PyTorchOperation._find_subclass(type_name) | |
elif debug_configs.framework.lower() in ('tf', 'tensorflow'): | |
from .operation_def import tf_op_def # pylint: disable=unused-import | |
cls = TensorFlowOperation._find_subclass(type_name) | |
else: | |
raise ValueError(f'Unsupported framework: {debug_configs.framework}') | |
return cls(type_name, parameters, _internal=True) | |
def _find_subclass(cls, subclass_name): | |
for subclass in cls.__subclasses__(): | |
if subclass.__name__ == subclass_name: | |
return subclass | |
return cls | |
def __repr__(self): | |
type_name = type(self).__name__ | |
args = [f'{key}={repr(value)}' for key, value in self.parameters.items()] | |
if type_name != self.type: | |
args = [f'type="{self.type}"'] + args | |
return f'{type_name}({", ".join(args)})' | |
def __eq__(self, other): | |
return type(other) is type(self) and other.type == self.type and other.parameters == self.parameters | |
class PyTorchOperation(Operation): | |
def _find_subclass(cls, subclass_name): | |
if cls.to_class_name(subclass_name) is not None: | |
subclass_name = 'ModuleOperator' | |
if cls.is_functional(subclass_name): | |
subclass_name = 'FunctionalOperator' | |
for subclass in cls.__subclasses__(): | |
if hasattr(subclass, '_ori_type_name') and \ | |
subclass_name in subclass._ori_type_name: | |
return subclass | |
return cls | |
def to_class_name(cls, type_name) -> str: | |
if type_name.startswith('__torch__.'): | |
return type_name[len('__torch__.'):] | |
elif type_name.startswith('__mutated__.'): | |
return type_name[len('__mutated__.'):] | |
else: | |
return None | |
def is_functional(cls, type_name) -> bool: | |
return type_name.startswith('Function.') | |
def _to_class_name(self) -> str: | |
if self.type.startswith('__torch__.'): | |
return self.type[len('__torch__.'):] | |
elif self.type.startswith('__mutated__.'): | |
return self.type[len('__mutated__.'):] | |
else: | |
return None | |
def get_import_pkg(self) -> str: | |
if self.type.startswith('__torch__.'): | |
return self.type[len('__torch__.'):].split('.')[0] | |
elif self.type.startswith('__mutated__.'): | |
return self.type[len('__mutated__.'):].split('.')[0] | |
else: | |
return None | |
def to_init_code(self, field: str) -> str: | |
if self._to_class_name() is not None: | |
assert 'positional_args' not in self.parameters | |
kw_params = ', '.join(f'{key}={repr(value)}' for key, value in self.parameters.items()) | |
return f'self.{field} = {self._to_class_name()}({kw_params})' | |
return None | |
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str: | |
""" | |
Parameters | |
---------- | |
field : str | |
the name of member submodule | |
output : str | |
the output name (lvalue) of this line of code | |
inputs : List[str] | |
variables used in this line of code | |
inputs_value : List[Any] | |
some variables are actually constant, their real values are recorded in ```inputs_value```. | |
if not constant, we simply put None at the corresponding index | |
Returns | |
------- | |
str | |
generated code line | |
""" | |
if self.type == 'aten::slice': | |
raise RuntimeError('not supposed to have aten::slice operation') | |
else: | |
raise RuntimeError(f'unsupported operation type: {self.type} ? {self._to_class_name()}') | |
class TensorFlowOperation(Operation): | |
def _to_class_name(self) -> str: | |
return 'K.layers.' + self.type | |
class Cell(PyTorchOperation): | |
""" | |
TODO: this is pytorch cell | |
An operation reference to a subgraph. | |
Example code: | |
``` | |
def __init__(...): | |
... | |
self.cell = CustomCell(...) | |
self.relu = K.layers.ReLU() | |
... | |
def forward(...): | |
... | |
x = self.cell(x) | |
... | |
``` | |
In above example, node `self.cell`'s operation is `Cell(cell_name='CustomCell')`. | |
For comparison, `self.relu`'s operation is `Operation(type='ReLU')`. | |
TODO: parameters of subgraph (see `Node` class) | |
Attributes | |
---------- | |
type | |
Always "_cell". | |
parameters | |
A dict with only one item; the key is "cell" and the value is cell's name. | |
framework | |
No real usage. Exists for compatibility with base class. | |
""" | |
def __init__(self, cell_name: str, parameters: Dict[str, Any] = {}): | |
self.type = '_cell' | |
self.cell_name = cell_name | |
self.parameters = parameters | |
def _to_class_name(self): | |
# TODO: ugly, think about how to refactor this part | |
return _convert_name(self.cell_name) | |
def to_forward_code(self, field: str, output: str, inputs: List[str], inputs_value: List[Any] = None) -> str: | |
return f'{output} = self.{field}({", ".join(inputs)})' | |
class _IOPseudoOperation(Operation): | |
""" | |
This is the pseudo operation used by I/O nodes. | |
The benefit is that users no longer need to verify `Node.operation is not None`, | |
especially in static type checking. | |
""" | |
def __init__(self, type_name: str, io_names: List = None): | |
assert type_name.startswith('_') | |
super(_IOPseudoOperation, self).__init__(type_name, {}, True) | |
self.io_names = io_names | |
def to_init_code(self, field: str) -> str: | |
raise ValueError(f'Cannot generate code for pseudo operation "{self.type}"') | |
def to_forward_code(self, field: str, output: str, inputs: List[str]) -> str: | |
raise ValueError(f'Cannot generate code for pseudo operation "{self.type}"') | |
def __bool__(self) -> bool: | |
return False | |