Llama-3.1-8B-DALv0.1
/
venv
/lib
/python3.12
/site-packages
/sympy
/tensor
/array
/expressions
/utils.py
import bisect | |
from collections import defaultdict | |
from sympy.combinatorics import Permutation | |
from sympy.core.containers import Tuple | |
from sympy.core.numbers import Integer | |
def _get_mapping_from_subranks(subranks): | |
mapping = {} | |
counter = 0 | |
for i, rank in enumerate(subranks): | |
for j in range(rank): | |
mapping[counter] = (i, j) | |
counter += 1 | |
return mapping | |
def _get_contraction_links(args, subranks, *contraction_indices):
    """
    Build a two-way lookup between the endpoints of pairwise contractions.

    Each group in ``contraction_indices`` holds flattened index positions,
    which are first translated to ``(argument, local index)`` pairs via
    ``_get_mapping_from_subranks(subranks)``.  Only groups with exactly two
    members produce links; every other group is ignored.

    Returns ``(args, links)`` where ``args`` is passed through unchanged and
    ``links[arg][pos]`` gives the ``(argument, local index)`` pair at the
    other end of the contraction.
    """
    nested_position = _get_mapping_from_subranks(subranks)
    # Translate every group up front, before any link is recorded.
    groups = [[nested_position[flat] for flat in group] for group in contraction_indices]
    links_by_arg = defaultdict(dict)
    for endpoints in groups:
        if len(endpoints) != 2:
            # Only pairwise contractions are recorded as links.
            continue
        (arg_a, pos_a), (arg_b, pos_b) = endpoints
        links_by_arg[arg_a][pos_a] = (arg_b, pos_b)
        links_by_arg[arg_b][pos_b] = (arg_a, pos_a)
    return args, dict(links_by_arg)
def _sort_contraction_indices(pairing_indices):
    """
    Return the index groups in a canonical order.

    Every group is sorted internally and wrapped in a ``Tuple``; the groups
    themselves are then ordered by their smallest member.  The ordering is
    stable, so groups sharing the same minimum keep their relative order.
    """
    normalized_groups = [Tuple(*sorted(group)) for group in pairing_indices]
    return sorted(normalized_groups, key=min)
def _get_diagonal_indices(flattened_indices):
    """
    Find repeated non-numeric indices and compute the resulting index order.

    Returns ``(diagonal_indices, ret_indices)``:

    * ``diagonal_indices`` is a list of tuples, one per repeated index,
      holding every position at which that index occurs;
    * ``ret_indices`` is a tuple with the non-repeated entries in their
      original order, followed by each repeated index once.

    Repeated indices are ordered by their first occurrence in
    ``flattened_indices``.
    """
    positions_by_index = defaultdict(list)
    for position, index in enumerate(flattened_indices):
        # A numeric index can never take part in a diagonal operation.
        if not isinstance(index, (int, Integer)):
            positions_by_index[index].append(position)

    repeated = {
        index: positions
        for index, positions in positions_by_index.items()
        if len(positions) > 1
    }

    # The diagonalized indices are placed at the end of the result.
    diag_indices = sorted(repeated, key=flattened_indices.index)
    untouched = [index for index in flattened_indices if index not in repeated]
    diagonal_indices = [tuple(repeated[index]) for index in diag_indices]
    return diagonal_indices, tuple(untouched + diag_indices)
def _get_argindex(subindices, ind): | |
for i, sind in enumerate(subindices): | |
if ind == sind: | |
return i | |
if isinstance(sind, (set, frozenset)) and ind in sind: | |
return i | |
raise IndexError("%s not found in %s" % (ind, subindices)) | |
def _apply_recursively_over_nested_lists(func, arr):
    """
    Apply ``func`` to every leaf of an arbitrarily nested sequence.

    ``tuple``, ``list`` and SymPy ``Tuple`` containers are traversed
    recursively and each one is rebuilt as a plain Python ``tuple``; any
    other object is treated as a leaf and replaced by ``func(leaf)``.

    Note that a ``Tuple`` input therefore comes back as a built-in
    ``tuple``.  A former ``elif isinstance(arr, Tuple)`` branch that rebuilt
    a ``Tuple`` could never run, because the first check already matched
    ``Tuple``; it has been removed without changing behaviour.
    """
    if isinstance(arr, (tuple, list, Tuple)):
        return tuple(_apply_recursively_over_nested_lists(func, i) for i in arr)
    return func(arr)
def _build_push_indices_up_func_transformation(flattened_contraction_indices): | |
shifts = {0: 0} | |
i = 0 | |
cumulative = 0 | |
while i < len(flattened_contraction_indices): | |
j = 1 | |
while i+j < len(flattened_contraction_indices): | |
if flattened_contraction_indices[i] + j != flattened_contraction_indices[i+j]: | |
break | |
j += 1 | |
cumulative += j | |
shifts[flattened_contraction_indices[i]] = cumulative | |
i += j | |
shift_keys = sorted(shifts.keys()) | |
def func(idx): | |
return shifts[shift_keys[bisect.bisect_right(shift_keys, idx)-1]] | |
def transform(j): | |
if j in flattened_contraction_indices: | |
return None | |
else: | |
return j - func(j) | |
return transform | |
def _build_push_indices_down_func_transformation(flattened_contraction_indices): | |
N = flattened_contraction_indices[-1]+2 | |
shifts = [i for i in range(N) if i not in flattened_contraction_indices] | |
def transform(j): | |
if j < len(shifts): | |
return shifts[j] | |
else: | |
return j + shifts[-1] - len(shifts) + 1 | |
return transform | |
def _apply_permutation_to_list(perm: Permutation, target_list: list):
    """
    Permute a list according to the given permutation.

    The element at position ``i`` of ``target_list`` is placed at position
    ``perm(i)`` of the result.  The result has ``perm.size`` slots; any slot
    that receives no element stays ``None``.
    """
    permuted = [None for _ in range(perm.size)]
    for source_pos, element in enumerate(target_list):
        destination = perm(source_pos)
        permuted[destination] = element
    return permuted