# File size: 8,043 Bytes
# e45d058
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
#################################################################################################
#
# Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: BSD-3-Clause
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are met:
#
# 1. Redistributions of source code must retain the above copyright notice, this
# list of conditions and the following disclaimer.
#
# 2. Redistributions in binary form must reproduce the above copyright notice,
# this list of conditions and the following disclaimer in the documentation
# and/or other materials provided with the distribution.
#
# 3. Neither the name of the copyright holder nor the names of its
# contributors may be used to endorse or promote products derived from
# this software without specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#
#################################################################################################

"""

Functions for manipulating IntTuples

"""

from functools import reduce
from itertools import chain
from typing import Union
from .typing import Integer


def is_int(x):
  """Return True when `x` is an instance of `Integer`.

  `Integer` is the name imported from this package's `.typing` module.
  """
  return isinstance(x, Integer)


def is_tuple(x):
  """Return True when `x` is a `tuple` (including tuple subclasses)."""
  return isinstance(x, tuple)


def flatten(t):
  """Return the leaves of `t` as one flat tuple, in left-to-right order.

  A non-tuple `t` is wrapped as the 1-tuple `(t,)`; an empty tuple
  (at any nesting depth) contributes no leaves.
  """
  if not is_tuple(t):
    return (t,)
  # An empty tuple simply yields nothing here, so no special case is needed.
  return tuple(leaf for child in t for leaf in flatten(child))


def signum(a):
  """Return 1, 0 or -1 according to whether `a` is positive, zero or negative."""
  is_positive = bool(a > 0)
  is_negative = bool(a < 0)
  # bool - bool evaluates to a plain int in {-1, 0, 1}.
  return is_positive - is_negative


def product(a):
  """Return the product of all leaves of `a`.

  A non-tuple `a` is returned unchanged; an empty tuple yields 1.
  """
  if not is_tuple(a):
    return a
  total = 1
  for elem in a:
    # Nested tuples are collapsed recursively before multiplying.
    total = total * product(elem)
  return total


def inner_product(a, b):
  """Return the sum of leaf-wise products of two congruent operands.

  When `a` is a tuple, `b` must have the same length and the result is the
  sum of the inner products of corresponding elements. Otherwise `b` must
  not be a tuple and the result is `a * b`.
  """
  if not is_tuple(a):                  # "int" "int"
    assert not is_tuple(b)
    return a * b
  # tuple tuple
  assert len(a) == len(b)
  return sum(map(inner_product, a, b))


def tuple_max(a):
  """Return the largest leaf of `a`; a non-tuple `a` is returned as is."""
  if not is_tuple(a):
    return a
  return max(map(tuple_max, a))


def elem_scale(a, b):
  """Scale `a` by `b`, returning a result shaped like `a`.

  - "int" "int":  `a * b`
  - "int" tuple:  `a` times the product of all leaves of `b`
  - tuple tuple:  element-wise `elem_scale`; lengths must match
  - tuple "int":  rejected with an assertion failure
  """
  if not is_tuple(a):
    if not is_tuple(b):                 # "int" "int"
      return a * b
    return elem_scale(a, product(b))    # "int" tuple
  if not is_tuple(b):                   # tuple "int"
    assert False           # Error
  else:                                 # tuple tuple
    assert len(a) == len(b)
    return tuple(map(elem_scale, a, b))


# Inclusive prefix ceil div with output congruent to input a
def shape_div(a, b):
  """Divide shape `a` by `b`, returning a result shaped like `a`.

  - "int" "int":  one operand must divide the other (asserted). Returns
                  `a // b` when `b` divides `a`, otherwise the sign of `a*b`.
  - "int" tuple:  `b` is first collapsed to the product of its leaves.
  - tuple tuple:  element-wise `shape_div`; lengths must match.
  - tuple "int":  modes of `a` are divided left to right; after each mode
                  the divisor itself is divided by that mode's size.
  """
  if not is_tuple(a):
    if is_tuple(b):                    # "int" tuple
      return shape_div(a, product(b))
    # "int" "int"
    assert a % b == 0 or b % a == 0
    return a // b if a % b == 0 else signum(a*b)

  if is_tuple(b):                      # tuple tuple
    assert len(a) == len(b)
    return tuple(map(shape_div, a, b))

  # tuple "int"
  divisor = b
  quotients = []
  for mode in a:
    quotients.append(shape_div(mode, divisor))
    # Whatever this mode absorbed is removed from the divisor for the next one.
    divisor = shape_div(divisor, product(mode))
  return tuple(quotients)


# Exclusive prefix product with output congruent to input a
def prefix_product(a, init=1):
  """Return the exclusive prefix product of `a`, shaped like `a`.

  - "int" "int":  the leaf receives the current `init`.
  - tuple "int":  modes are visited left to right; each mode is expanded
                  with the running value, which is then multiplied by the
                  product of that mode's leaves.
  - tuple tuple:  element-wise `prefix_product`; lengths must match.
  - "int" tuple:  rejected with an assertion failure.
  """
  if not is_tuple(a):
    if is_tuple(init):                 # "int" tuple
      assert False           # Error
    else:                              # "int" "int"
      return init
  elif is_tuple(init):                 # tuple tuple
    assert len(a) == len(init)
    return tuple(map(prefix_product, a, init))
  else:                                # tuple "int"
    running = init
    prefixes = []
    for mode in a:
      prefixes.append(prefix_product(mode, running))
      running = running * product(mode)
    return tuple(prefixes)


def idx2crd(idx, shape, stride=None):
  """Convert `idx` into a coordinate shaped like `shape`.

  When `stride` is None it defaults to `prefix_product(shape)`.

  - "int" "int" "int":  `(idx // stride) % shape`
  - "int" tuple tuple:  the same `idx` is decoded against every mode
  - tuple tuple tuple:  element-wise `idx2crd`; all lengths must match
  - tuple "int" "int":  rejected with an assertion failure
  """
  strides = prefix_product(shape) if stride is None else stride

  if not is_tuple(shape):
    if is_tuple(idx):                  # tuple "int" "int"
      assert False           # Error
    else:                              # "int" "int" "int"
      return (idx // strides) % shape
  elif is_tuple(idx):                  # tuple tuple tuple
    assert len(idx) == len(shape) and len(idx) == len(strides)
    return tuple(map(idx2crd, idx, shape, strides))
  else:                                # "int" tuple tuple
    assert len(shape) == len(strides)
    return tuple(idx2crd(idx, extent, step) for extent, step in zip(shape, strides))


def crd2idx(crd, shape, stride=None):
  """Convert coordinate `crd` within `shape` into a linear index.

  When `stride` is None it defaults to `prefix_product(shape)`.
  A non-tuple `crd` of None is treated as 0.

  - "int" "int" "int":  `crd * stride`
  - "int" tuple tuple:  `crd` is split across the modes of `shape` (modulo
                        and floor-divide by each mode's size, left to
                        right); the last mode receives the remainder.
  - tuple tuple tuple:  sum of element-wise `crd2idx`; lengths must match
  - tuple "int" "int":  rejected with an assertion failure
  """
  strides = prefix_product(shape) if stride is None else stride

  if is_tuple(crd):
    if not is_tuple(shape):            # tuple "int" "int"
      assert False, f"crd={crd}, shape={shape}"           # Error
    else:                              # tuple tuple tuple
      assert len(crd) == len(shape) and len(crd) == len(strides)
      return sum(map(crd2idx, crd, shape, strides))
  else:
    remaining = 0 if crd is None else crd

    if not is_tuple(shape):            # "int" "int" "int"
      return remaining * strides

    # "int" tuple tuple
    assert len(shape) == len(strides)
    offset = 0
    for i, extent in enumerate(shape[:-1]):
      size = product(extent)
      offset += crd2idx(remaining % size, extent, strides[i])
      remaining = remaining // size
    # The last mode is not wrapped: it takes whatever is left of the coordinate.
    return offset + crd2idx(remaining, shape[-1], strides[-1])


# Transform crd into the dst_shape's iteration space
def crd2crd(crd, dst_shape, src_shape=None):
  """Re-express `crd` as a coordinate in `dst_shape`'s iteration space.

  - "int" "int":  `crd` must be smaller than `dst_shape` (asserted) and is
                  returned unchanged.
  - "int" tuple:  `crd` is expanded with `idx2crd(crd, dst_shape)`.
  - tuple "int":  `crd` is collapsed with `crd2idx(crd, src_shape)`; this is
                  ambiguous without `src_shape`, so it must be provided.
  - tuple tuple:  element-wise `crd2crd`; lengths must match.

  `src_shape` is the shape that `crd` was expressed in. When it is a tuple
  of the same rank as `crd`, each mode's source shape is passed down with
  that mode, so nested coordinates facing an integer destination mode can
  still be collapsed. For example
  `crd2crd(((1, 2),), (6,), ((2, 3),))` evaluates to `(5,)`.
  In any other case the tuple/tuple recursion ignores `src_shape`.
  """
  if is_tuple(crd):
    if is_tuple(dst_shape):            # tuple tuple
      assert len(crd) == len(dst_shape)
      if is_tuple(src_shape) and len(src_shape) == len(crd):
        # Keep the source shape paired with its mode; dropping it here would
        # make every nested tuple-vs-int case below fail its assertion.
        return tuple(crd2crd(c, d, s) for c, d, s in zip(crd, dst_shape, src_shape))
      return tuple(crd2crd(c, d) for c, d in zip(crd, dst_shape))
    else:                              # tuple "int"
      # Ambiguous unless we have src_shape
      assert src_shape is not None
      return crd2idx(crd, src_shape)
  else:
    if is_tuple(dst_shape):            # "int" tuple
      return idx2crd(crd, dst_shape)
    else:                              # "int" "int"
      assert crd < dst_shape
      return crd


# Filter trg according to crd: keep only elements of trg that are paired with None
def slice_(crd: Union[None, tuple, int],
           trg: Union[tuple, int]):
  """Return, as a flat tuple, the parts of `trg` that `crd` pairs with None.

  - `crd` is None:     `trg` is kept whole and returned as `(trg,)`.
  - `crd` is a tuple:  `trg` must be a tuple of the same length; the results
                       for each pair are concatenated in order.
  - anything else:     nothing is kept and `()` is returned.
  """
  if crd is None:
    # match C++ behavior `return cute::tuple<B>{b};`
    return (trg,)
  if not is_tuple(crd):
    return ()

  if not is_tuple(trg):
    assert False                     # tuple "int" : Error
  else:                              # tuple tuple
    assert len(crd) == len(trg)
    # match C++ behavior of `filter_tuple` using `tuple_cat(...)`:
    # empty per-mode results contribute nothing to the concatenation.
    kept = []
    for c, s in zip(crd, trg):
      kept.extend(slice_(c, s))
    return tuple(kept)


# Determine if None appears at any of an int_tuples' terminals
def has_none(a: Union[None, tuple, int]):
  """Return True when at least one leaf of `a` is None, otherwise False."""
  if not is_tuple(a):
    return a is None
  for v in a:
    if has_none(v):
      return True
  return False