Spaces:
Sleeping
Sleeping
################################################################################################# | |
# | |
# Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. | |
# SPDX-License-Identifier: BSD-3-Clause | |
# | |
# Redistribution and use in source and binary forms, with or without | |
# modification, are permitted provided that the following conditions are met: | |
# | |
# 1. Redistributions of source code must retain the above copyright notice, this | |
# list of conditions and the following disclaimer. | |
# | |
# 2. Redistributions in binary form must reproduce the above copyright notice, | |
# this list of conditions and the following disclaimer in the documentation | |
# and/or other materials provided with the distribution. | |
# | |
# 3. Neither the name of the copyright holder nor the names of its | |
# contributors may be used to endorse or promote products derived from | |
# this software without specific prior written permission. | |
# | |
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" | |
# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE | |
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE | |
# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE | |
# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL | |
# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR | |
# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER | |
# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, | |
# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE | |
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. | |
# | |
################################################################################################# | |
""" | |
Methods for layout swizzling | |
""" | |
from .layout import * | |
def shiftr(a, s):
    """Shift ``a`` right by ``s`` bits; a negative ``s`` shifts left by ``-s``.

    ``s == 0`` returns ``a`` unchanged. The negative case is dispatched
    explicitly because Python raises ValueError for a negative shift count.
    """
    return a >> s if s >= 0 else a << -s
def shiftl(a, s):
    """Shift ``a`` left by ``s`` bits; a negative ``s`` shifts right by ``-s``.

    ``s == 0`` returns ``a`` unchanged. The negative case is dispatched
    explicitly because Python raises ValueError for a negative shift count.
    """
    return a << s if s >= 0 else a >> -s
## A generic Swizzle functor
# 0bxxxxxxxxxxxxxxxYYYxxxxxxxZZZxxxx
#                               ^--^  Base is the number of least-significant bits kept constant
#                  ^-^       ^-^      Bits is the number of bits in each mask
#                    ^---------^      Shift is the distance to shift the YYY mask
#                                     (positive shifts YYY to the right, negative shifts YYY to the left)
#
# e.g. Given
# 0bxxxxxxxxxxxxxxxxYYxxxxxxxxxZZxxx
# the result is
# 0bxxxxxxxxxxxxxxxxYYxxxxxxxxxAAxxx where AA = ZZ xor YY
#
class Swizzle:
    """XOR-based bit swizzle applied to integer offsets.

    ``Swizzle(bits, base, shift)(offset)`` XORs the ``bits``-wide field YYY
    (starting at bit ``base + max(0, shift)``) into the field ZZZ (starting
    at bit ``base - min(0, shift)``). All other bits of ``offset`` pass
    through unchanged.

    Args:
        bits: width of the YYY/ZZZ masks; must be >= 0.
        base: number of least-significant bits kept constant; must be >= 0.
        shift: signed distance from ZZZ to YYY; ``abs(shift)`` must be
            >= ``bits`` so the two masks do not overlap.
    """

    def __init__(self, bits, base, shift):
        assert bits >= 0
        assert base >= 0
        assert abs(shift) >= bits
        self.bits = bits
        self.base = base
        self.shift = shift
        bit_msk = (1 << bits) - 1
        self.yyy_msk = bit_msk << (base + max(0, shift))
        self.zzz_msk = bit_msk << (base - min(0, shift))

    # operator () (transform integer)
    def __call__(self, offset):
        yyy = offset & self.yyy_msk
        # Pick the shift direction here rather than going through a
        # sign-dispatching helper, so that shift == 0 (e.g. the identity
        # Swizzle(0, 0, 0)) is handled directly and cannot recurse.
        if self.shift >= 0:
            return offset ^ (yyy >> self.shift)
        return offset ^ (yyy << -self.shift)

    # Size of the domain: 2 ** (bits + base + |shift|)
    def size(self):
        return 1 << (self.bits + self.base + abs(self.shift))

    # Size of the codomain (same as the domain; the swizzle is a bijection on it)
    def cosize(self):
        return self.size()

    # operator == (value equality, so ComposedLayout comparisons of equal
    # swizzles succeed)
    def __eq__(self, other):
        if not isinstance(other, Swizzle):
            return NotImplemented
        return (self.bits, self.base, self.shift) == (other.bits, other.base, other.shift)

    # Hash consistent with __eq__ (defining __eq__ alone would make instances unhashable)
    def __hash__(self):
        return hash((self.bits, self.base, self.shift))

    # print and str
    def __str__(self):
        return f"SW_{self.bits}_{self.base}_{self.shift}"

    # error msgs and representation
    def __repr__(self):
        return f"Swizzle({self.bits},{self.base},{self.shift})"
class ComposedLayout(LayoutBase):
    """Layout formed by the composition ``layoutB o offset o layoutA``.

    A coordinate is first mapped through ``layoutA``; ``offset`` is added to
    the resulting index, and the sum is then mapped through ``layoutB``
    (for example a ``Swizzle``).

    Args:
        layoutB: outer callable applied last.
        offset: value added between the two mappings.
        layoutA: inner layout; it determines rank (``len``), sub-layout
            indexing (``[]``) and the domain size.
    """

    def __init__(self, layoutB, offset, layoutA):
        self.layoutB = layoutB
        self.offset = offset
        self.layoutA = layoutA

    # operator ==
    def __eq__(self, other):
        # Foreign types get NotImplemented so e.g. ``composed == 5`` evaluates
        # to False instead of raising AttributeError on ``other.layoutB``.
        if not isinstance(other, ComposedLayout):
            return NotImplemented
        return (
            self.layoutB == other.layoutB
            and self.offset == other.offset
            and self.layoutA == other.layoutA
        )

    # operator len(L) (len [rank] like tuples), taken from the inner layout
    def __len__(self):
        return len(self.layoutA)

    # operator () (map coord to idx): layoutB(offset + layoutA(coord))
    def __call__(self, *args):
        return self.layoutB(self.offset + self.layoutA(*args))

    # operator [] (get-i like tuples): sub-layout of layoutA, keeping the
    # same outer layoutB and offset
    def __getitem__(self, i):
        return ComposedLayout(self.layoutB, self.offset, self.layoutA[i])

    # size(layout) Size of the domain (that of the inner layoutA)
    def size(self):
        # ``size`` resolves to the module-level helper pulled in by
        # ``from .layout import *``, not to this method.
        return size(self.layoutA)

    # cosize(layout) Size of the codomain (taken from the outer layoutB)
    def cosize(self):
        # Module-level ``cosize`` helper, presumably from ``.layout`` -- verify.
        return cosize(self.layoutB)

    # print and str
    def __str__(self):
        return f"{self.layoutB} o {self.offset} o {self.layoutA}"

    # error msgs and representation
    def __repr__(self):
        return f"ComposedLayout({repr(self.layoutB)},{repr(self.offset)},{repr(self.layoutA)})"