File size: 1,579 Bytes
82ea528 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 |
import torch
import torch.nn.functional as F
def find_subsequence(sequence, sub_sequence):
    """Find every start position where ``sub_sequence`` occurs in ``sequence``.

    Both arguments are indexed with ``[0]`` before matching, so each must
    carry a leading dimension; ``sequence`` must have size 1 along it.
    Presumably both are token-id tensors shaped ``(1, n)`` -- TODO confirm
    against callers.

    Args:
        sequence: Tensor whose first dimension has size 1.
        sub_sequence: Tensor whose first row is the pattern to search for.

    Returns:
        A tuple ``(indices, count, sub_len)`` where ``indices`` is the list of
        start offsets of each match (overlapping matches included, since the
        window advances one element at a time), ``count`` is
        ``len(indices)`` and ``sub_len`` is the pattern length.

    Raises:
        AssertionError: If ``sequence.shape[0] != 1``.
    """
    assert sequence.shape[0] == 1
    sequence = sequence[0]
    sub_sequence = sub_sequence[0]
    sub_len = len(sub_sequence)

    # ``unfold`` raises when the window is longer than the dimension it slides
    # over; a pattern longer than the sequence simply has no matches.
    if sub_len > len(sequence):
        return [], 0, sub_len

    # Every length-``sub_len`` window (step 1), compared element-wise with the
    # pattern; a window matches only if all of its positions agree.
    windows = sequence.unfold(0, sub_len, 1)
    matches = (windows == sub_sequence).all(dim=1)
    indices = matches.nonzero().flatten().tolist()
    return indices, len(indices), sub_len
import ast
import torch
def multi_slice_to_mask(expr, length):
    """Build a boolean mask of size ``length`` from an index/slice expression.

    ``expr`` is a comma-separated list of items. An item containing ``:`` is
    treated as a Python-style slice (``start:stop[:step]``, any field may be
    omitted); any other item is treated as a single integer index. Selected
    positions are set to ``True``.

    Negative indices count from the end. Indices that still fall outside
    ``[0, length)`` are silently ignored.

    Args:
        expr: Expression such as ``"0:3,5,-1"``. An empty or whitespace-only
            string selects nothing.
        length: Size of the mask to create.

    Returns:
        A 1-D ``torch.bool`` tensor with ``length`` elements.

    Raises:
        ValueError: If the expression cannot be parsed or applied; the
            underlying error is attached as ``__cause__``.
    """

    def process_single_slice(s):
        # Rewrite "a:b:c" as "a,b,c" and fill every omitted field with None so
        # the result can be parsed as a tuple of slice() arguments.
        s = s.replace(':', ',').replace(' ', '')
        while ',,' in s:
            s = s.replace(',,', ',None,')
        if s.startswith(','):
            s = 'None' + s
        if s.endswith(','):
            s = s + 'None'
        return s

    try:
        mask = torch.zeros(length, dtype=torch.bool)
        if not expr.strip():
            return mask

        for item in expr.split(','):
            # Strip so "1, 3" parses the same way on every Python version.
            item = item.strip()

            if ':' in item:
                slice_args = ast.literal_eval(f"({process_single_slice(item)})")
                mask[slice(*slice_args)] = True
                continue

            idx = ast.literal_eval(item)
            # bool is an int subclass; letting it through would evaluate
            # mask[True] = True and switch on the whole mask.
            if isinstance(idx, bool) or not isinstance(idx, int):
                raise TypeError(f"index must be an integer, got {item!r}")
            if idx < 0:
                idx += length
            if 0 <= idx < length:
                mask[idx] = True
        return mask
    except Exception as e:
        # Single error type for callers, with the original cause preserved.
        raise ValueError(f"Invalid slice expression: {e}") from e
|