Spaces:
Sleeping
Sleeping
################################################################################################# | |
# | |
# Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | |
# SPDX-License-Identifier: BSD-3-Clause | |
# | |
# Redistribution and use in source and binary forms, with or without | |
# modification, are permitted provided that the following conditions are met: | |
# | |
# 1. Redistributions of source code must retain the above copyright notice, this | |
# list of conditions and the following disclaimer. | |
# | |
# 2. Redistributions in binary form must reproduce the above copyright notice, | |
# this list of conditions and the following disclaimer in the documentation | |
# and/or other materials provided with the distribution. | |
# | |
# 3. Neither the name of the copyright holder nor the names of its | |
# contributors may be used to endorse or promote products derived from | |
# this software without specific prior written permission. | |
# | |
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" | |
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE | |
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE | |
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE | |
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL | |
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR | |
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER | |
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, | |
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE | |
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. | |
# | |
################################################################################################# | |
""" | |
Functions for manipulating IntTuples | |
""" | |
from functools import reduce | |
from itertools import chain | |
from typing import Union | |
from .typing import Integer | |
def is_int(x): | |
return isinstance(x, Integer) | |
def is_tuple(x): | |
return isinstance(x, tuple) | |
def flatten(t): | |
if is_tuple(t): | |
if len(t) == 0: | |
return () | |
else: | |
return tuple(i for a in t for i in flatten(a)) | |
else: | |
return (t,) | |
def signum(a): | |
return bool(a > 0) - bool(a < 0) | |
def product(a): | |
if is_tuple(a): | |
return reduce(lambda val,elem : val*product(elem), a, 1) | |
else: | |
return a | |
def inner_product(a, b): | |
if is_tuple(a): # tuple tuple | |
assert len(a) == len(b) | |
return sum(inner_product(x,y) for x,y in zip(a,b)) | |
else: # "int" "int" | |
assert not is_tuple(b) | |
return a * b | |
def tuple_max(a): | |
if is_tuple(a): | |
return max(tuple_max(x) for x in a) | |
else: | |
return a | |
def elem_scale(a, b): | |
if is_tuple(a): | |
if is_tuple(b): # tuple tuple | |
assert len(a) == len(b) | |
return tuple(elem_scale(x,y) for x,y in zip(a,b)) | |
else: # tuple "int" | |
assert False # Error | |
else: | |
if is_tuple(b): # "int" tuple | |
return elem_scale(a, product(b)) | |
else: # "int" "int" | |
return a * b | |
# Inclusive prefix ceil div with output congruent to input a | |
def shape_div(a, b): | |
if is_tuple(a): | |
if is_tuple(b): # tuple tuple | |
assert len(a) == len(b) | |
return tuple(shape_div(x,y) for x,y in zip(a,b)) | |
else: # tuple "int" | |
#r = [shape_div(a[0],b)] + [shape_div(a[i],b := shape_div(b, product(a[i-1]))) for i in range(1,len(a))] | |
r = [] | |
for v in a: | |
r.append(shape_div(v,b)) | |
b = shape_div(b,product(v)) | |
return tuple(r) | |
else: | |
if is_tuple(b): # "int" tuple | |
return shape_div(a, product(b)) | |
else: # "int" "int" | |
assert a % b == 0 or b % a == 0 | |
#return -(-a // b) # Python exclusive impl: "//" is always floor div | |
if a % b == 0: | |
return a // b | |
else: | |
return signum(a*b) | |
# Exclusive prefix product with output congruent to input a | |
def prefix_product(a, init=1): | |
if is_tuple(a): | |
if is_tuple(init): # tuple tuple | |
assert len(a) == len(init) | |
return tuple(prefix_product(x,i) for x,i in zip(a,init)) | |
else: # tuple "int" | |
#r = [prefix_product(a[0],init)] + [prefix_product(a[i],init := init * product(a[i-1])) for i in range(1,len(a))] | |
r = [] | |
for v in a: | |
r.append(prefix_product(v,init)) | |
init = init * product(v) | |
return tuple(r) | |
else: | |
if is_tuple(init): # "int" tuple | |
assert False # Error | |
else: # "int" "int" | |
return init | |
def idx2crd(idx, shape, stride=None): | |
if stride is None: | |
stride = prefix_product(shape) | |
if is_tuple(idx): | |
if is_tuple(shape): # tuple tuple tuple | |
assert len(idx) == len(shape) and len(idx) == len(stride) | |
return tuple(idx2crd(i, s, d) for i, s, d in zip(idx,shape,stride)) | |
else: # tuple "int" "int" | |
assert False # Error | |
else: | |
if is_tuple(shape): # "int" tuple tuple | |
assert len(shape) == len(stride) | |
return tuple(idx2crd(idx, s, d) for s,d in zip(shape,stride)) | |
else: # "int" "int" "int" | |
return (idx // stride) % shape | |
def crd2idx(crd, shape, stride=None): | |
if stride is None: | |
stride = prefix_product(shape) | |
if is_tuple(crd): | |
if is_tuple(shape): # tuple tuple tuple | |
assert len(crd) == len(shape) and len(crd) == len(stride) | |
return sum(crd2idx(c, s, d) for c, s, d in zip(crd, shape, stride)) | |
else: # tuple "int" "int" | |
assert False, f"crd={crd}, shape={shape}" # Error | |
else: | |
if crd is None: | |
crd = 0 | |
if is_tuple(shape): # "int" tuple tuple | |
assert len(shape) == len(stride) | |
result = 0 | |
for i in range(len(shape)-1): | |
result += crd2idx(crd % product(shape[i]), shape[i], stride[i]) | |
crd = crd // product(shape[i]) | |
return result + crd2idx(crd, shape[-1], stride[-1]) | |
else: # "int" "int" "int" | |
return crd * stride | |
# Transform crd into the dst_shape's iteration space | |
def crd2crd(crd, dst_shape, src_shape=None): | |
if is_tuple(crd): | |
if is_tuple(dst_shape): # tuple tuple | |
assert len(crd) == len(dst_shape) | |
return tuple(crd2crd(x, y) for x, y in zip(crd,dst_shape)) | |
else: # tuple "int" | |
# Ambiguous unless we have src_shape | |
assert src_shape is not None | |
return crd2idx(crd, src_shape) | |
else: | |
if is_tuple(dst_shape): # "int" tuple | |
return idx2crd(crd, dst_shape) | |
else: # "int" "int" | |
assert crd < dst_shape | |
return crd | |
# Filter trg according to crd: keep only elements of trg that are paired with None | |
def slice_(crd: Union[None, tuple, int], | |
trg: Union[tuple, int]): | |
if is_tuple(crd): | |
if is_tuple(trg): # tuple tuple | |
assert len(crd) == len(trg) | |
# match C++ behavior of `filter_tuple` using `tuple_cat(...)` | |
return tuple(chain(*filter(lambda x: x != (), [slice_(c, s) for c, s in zip(crd, trg)]))) | |
else: | |
assert False # tuple "int" : Error | |
elif crd is None: | |
# match C++ behavior `return cute::tuple<B>{b};` | |
return (trg,) | |
else: | |
return () | |
# Determine if None appears at any of an int_tuples' terminals | |
def has_none(a: Union[None, tuple, int]): | |
if is_tuple(a): | |
return any(has_none(v) for v in a) | |
else: | |
return a is None | |