jaxmetaverse's picture
Upload folder using huggingface_hub
82ea528 verified
import torch
import torch.nn.functional as F
def find_subsequence(sequence, sub_sequence):
assert sequence.shape[0]==1
sequence = sequence[0]
sub_sequence = sub_sequence[0]
sub_len = len(sub_sequence)
indices = []
windows = sequence.unfold(0, sub_len, 1)
matches = (windows == sub_sequence).all(dim=1)
indices = matches.nonzero().flatten().tolist()
return indices, len(indices), sub_len
import ast
import torch
def multi_slice_to_mask(expr, length):
def process_single_slice(s):
s = s.replace(':', ',').replace(' ', '')
while ',,' in s:
s = s.replace(',,', ',None,')
if s.startswith(','):
s = 'None' + s
if s.endswith(','):
s = s + 'None'
return s
try:
slices = expr.split(',')
mask = torch.zeros(length, dtype=torch.bool)
if expr == "":
return mask
i = 0
while i < len(slices):
if ':' in slices[i]:
slice_expr = process_single_slice(slices[i])
slice_args = ast.literal_eval(f"({slice_expr})")
s = slice(*slice_args)
mask[s] = True
i += 1
else:
idx = ast.literal_eval(slices[i])
if idx < 0:
idx = length + idx
if 0 <= idx < length:
mask[idx] = True
i += 1
return mask
except Exception as e:
raise ValueError(f"Invalid slice expression: {e}")