File size: 3,519 Bytes
375a1cf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
# Copyright (c) Facebook, Inc. and its affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import dis
import inspect

from dataclasses import dataclass
from typing import Union

from . import DimList

# Stack of LevelInfo records: the Dim.size setter appends one each time a Dim
# is bound, and Dim.__del__ pops records from the end.
_vmap_levels = []


@dataclass
class LevelInfo:
    """Bookkeeping record stored on the module-level ``_vmap_levels`` stack."""

    # Value the Dim.size setter obtained from `_vmap_increment_nesting`.
    level: int
    # Starts True; Dim.__del__ only pops a record from the top of the stack
    # once this flag is False.
    alive: bool = True


class Dim:
    """A named dimension that can be bound to a concrete size exactly once.

    Binding a size (via the constructor or the ``size`` setter) calls
    ``_vmap_increment_nesting`` and pushes a ``LevelInfo`` record onto the
    module-level ``_vmap_levels`` stack. The finalizer marks that record dead
    and unwinds any dead records sitting on top of the stack.

    NOTE(review): ``_vmap_increment_nesting``, ``_vmap_decrement_nesting`` and
    ``current_level`` are neither defined nor imported in this module (hence
    the ``noqa: F821`` markers); confirm where they are supplied from.
    """

    def __init__(self, name: str, size: Union[None, int] = None):
        """Create a dimension called ``name``, optionally bound to ``size``."""
        self.name = name
        self._size = None
        self._vmap_level = None
        if size is not None:
            # Route through the property setter so the level bookkeeping runs.
            self.size = size

    def __del__(self):
        """Mark this dimension's level dead and unwind dead levels on top."""
        if self._vmap_level is None:
            # Never bound, so no level was ever pushed for this object.
            return

        # The record for this Dim was appended to `_vmap_levels` at index
        # `_vmap_stack` by the size setter. The previous code flagged it
        # through an undefined name, so the finalizer always failed and the
        # loop below could never observe a dead record.
        _vmap_levels[self._vmap_stack].alive = False

        # Pop from the top while the top record is dead and is the current
        # level. The emptiness check comes first so that popping the very
        # last record does not index into an empty list.
        while (
            _vmap_levels
            and not _vmap_levels[-1].alive
            and current_level() == _vmap_levels[-1].level  # noqa: F821
        ):
            _vmap_decrement_nesting()  # noqa: F821
            _vmap_levels.pop()

    @property
    def size(self):
        """Return the bound size; asserts that the dimension is bound."""
        assert self.is_bound
        return self._size

    @size.setter
    def size(self, size: int):
        """Bind the dimension to ``size``.

        The first binding stores the size and pushes a new level record.
        Re-binding to the same size is a no-op.

        Raises:
            DimensionBindError: if already bound to a different size.
        """
        # Imported at call time rather than at module import.
        # NOTE(review): presumably to avoid a circular import -- confirm.
        from . import DimensionBindError

        if self._size is None:
            self._size = size
            self._vmap_level = _vmap_increment_nesting(size, "same")  # noqa: F821
            # Remember where the record lives so __del__ can flag it later.
            self._vmap_stack = len(_vmap_levels)
            _vmap_levels.append(LevelInfo(self._vmap_level))

        elif self._size != size:
            raise DimensionBindError(
                f"Dim '{self}' previously bound to a dimension of size {self._size} cannot bind to a dimension of size {size}"
            )

    @property
    def is_bound(self):
        """Whether a size has been assigned to this dimension."""
        return self._size is not None

    def __repr__(self):
        return self.name


def extract_name(inst):
    """Return the variable name written by a store instruction.

    ``inst`` must have an ``opname`` of ``STORE_FAST`` or ``STORE_NAME``;
    anything else trips the assertion. The instruction's ``argval`` is
    returned unchanged.
    """
    is_plain_store = inst.opname in ("STORE_FAST", "STORE_NAME")
    assert is_plain_store
    return inst.argval


# Memoized zero-argument factories built by dims(), so the caller's bytecode
# is only inspected the first time a given call is seen.
_cache = {}


def dims(lists=0):
    """Create dimension objects named after the variables they are assigned to.

    The caller's bytecode is inspected to find the store instruction(s) that
    receive this call's result, and their target names become the dimension
    names. ``x = dims()`` builds one object named ``"x"``; ``a, b = dims()``
    builds a tuple with one object per target.

    Args:
        lists: For an unpacking assignment, the number of trailing targets
            created as ``DimList`` instead of ``Dim``. For a single target,
            any value other than 0 selects ``DimList``.

    Returns:
        A single ``Dim``/``DimList``, or a tuple of them for an unpacking
        assignment.

    Raises:
        AssertionError: if there is no calling frame, or the instruction after
            the call is not ``STORE_FAST``, ``STORE_NAME`` or
            ``UNPACK_SEQUENCE`` (followed by plain stores).
    """
    frame = inspect.currentframe()
    assert frame is not None
    calling_frame = frame.f_back
    assert calling_frame is not None
    code, lasti = calling_frame.f_code, calling_frame.f_lasti

    # `lists` is part of the key because the factory built below bakes it in;
    # without it, one call site invoked with different `lists` values would
    # keep returning whatever the first invocation produced.
    key = (code, lasti, lists)
    factory = _cache.get(key)
    if factory is None:
        instructions = list(dis.get_instructions(code))
        # Find the instruction after the call by byte offset. The earlier
        # `lasti // 2 + 1` index assumed every 2-byte code unit is listed, but
        # CPython 3.11+ interleaves inline cache entries that
        # dis.get_instructions hides by default, which shifted the index onto
        # the wrong instruction. Falling back to len(instructions) preserves
        # the IndexError when nothing follows the call.
        # NOTE(review): interpreters that fuse stores into superinstructions
        # are not handled here -- confirm against the supported versions.
        first = next(
            (idx for idx, inst in enumerate(instructions) if inst.offset > lasti),
            len(instructions),
        )
        unpack = instructions[first]

        if unpack.opname == "STORE_FAST" or unpack.opname == "STORE_NAME":
            # just a single dim, not a list
            name = unpack.argval
            ctor = Dim if lists == 0 else DimList

            def factory():
                return ctor(name=name)

        else:
            assert unpack.opname == "UNPACK_SEQUENCE"
            ndims = unpack.argval
            # The stores for the unpacked targets follow UNPACK_SEQUENCE.
            names = tuple(
                extract_name(instructions[first + 1 + i]) for i in range(ndims)
            )
            # Targets at or beyond this index become DimList objects.
            first_list = len(names) - lists

            def factory():
                return tuple(
                    Dim(n) if i < first_list else DimList(name=n)
                    for i, n in enumerate(names)
                )

        _cache[key] = factory
    return factory()


def _dim_set(positional, arg):
    """Resolve ``arg`` into a collection of dimensions.

    ``arg`` may be ``None`` (``positional`` is returned as-is), a single
    ``Dim`` or ``int``, or an iterable of those. Each ``int`` is used as an
    index into ``positional``; each ``Dim`` is kept unchanged. Apart from the
    ``None`` case the result is always a tuple.
    """
    if arg is None:
        return positional

    def resolve(item):
        # Dims pass straight through; anything else must be an integer index.
        if not isinstance(item, Dim):
            assert isinstance(item, int)
            return positional[item]
        return item

    # Treat a lone Dim/int as a one-element sequence.
    items = (arg,) if isinstance(arg, (Dim, int)) else arg
    return tuple(resolve(item) for item in items)