|
import numpy as np |
|
from typing import Callable, Optional, List |
|
|
|
|
|
def ordered_halving(val):
    """Map an integer to a fraction by mirroring its 64-bit binary form.

    ``val`` is rendered as a zero-padded 64-character binary string, the
    string is reversed, parsed back as a base-2 integer and divided by
    ``2 ** 64``. For example 1 -> 0.5, 2 -> 0.25, 3 -> 0.75.

    Args:
        val: Integer to convert (must support the ``b`` format code).

    Returns:
        The mirrored bit pattern as a float fraction of ``2 ** 64``.
    """
    mirrored_bits = format(val, "064b")[::-1]
    return int(mirrored_bits, 2) / (2 ** 64)
|
|
|
def does_window_roll_over(window: list[int], num_frames: int) -> tuple[bool, int]:
    """Detect whether a window wraps around past the last frame.

    Each index is reduced modulo ``num_frames``; the window "rolls over" as
    soon as one reduced index is smaller than the one before it.

    Args:
        window: Frame indexes in window order.
        num_frames: Modulus applied to every index before comparing.

    Returns:
        ``(True, i)`` where ``i`` is the position of the first element whose
        wrapped value drops below its predecessor, or ``(False, -1)`` when
        the wrapped values never decrease.
    """
    wrapped_frames = (frame % num_frames for frame in window)
    previous = -1  # sentinel below any wrapped value for a positive modulus
    for position, current in enumerate(wrapped_frames):
        if previous > current:
            return True, position
        previous = current
    return False, -1
|
|
|
def shift_window_to_start(window: list[int], num_frames: int):
    """Slide a window, in place, so that its first index becomes 0.

    Every entry has the first entry subtracted from it; ``num_frames`` is
    then added and the result taken modulo ``num_frames`` so that entries
    which had wrapped around end up after the ones that had not.

    Args:
        window: Frame indexes; modified in place. Must be non-empty
            (an empty list raises ``IndexError``).
        num_frames: Modulus used to re-wrap the shifted indexes.
    """
    origin = window[0]
    # Slice assignment keeps the caller's list object and rewrites its contents.
    window[:] = [((frame - origin) + num_frames) % num_frames for frame in window]
|
|
|
def shift_window_to_end(window: list[int], num_frames: int):
    """Slide a window, in place, so that its last index is ``num_frames - 1``.

    The window is first normalised with ``shift_window_to_start`` and then
    every entry is moved forward by the same offset, chosen so that the
    final entry lands on ``num_frames - 1``.

    Args:
        window: Frame indexes; modified in place.
        num_frames: Total number of frames the window is positioned within.
    """
    # Step 1: normalise so the window begins at 0.
    shift_window_to_start(window, num_frames)
    # Step 2: push every entry forward until the last one hits the final frame.
    offset = num_frames - window[-1] - 1
    window[:] = [frame + offset for frame in window]
|
|
|
def get_missing_indexes(windows: list[list[int]], num_frames: int) -> list[int]:
    """Return the frame indexes that no window covers.

    Args:
        windows: Windows of frame indexes. Values outside
            ``range(num_frames)`` and repeated values are ignored.
        num_frames: Number of frames; candidates are ``0 .. num_frames - 1``.

    Returns:
        Ascending list of indexes in ``range(num_frames)`` that appear in
        none of the windows.
    """
    # One pass to collect what is covered, one pass to filter. This replaces
    # repeated list.remove() calls (a linear scan each, plus an exception for
    # every repeated or out-of-range value).
    covered = {val for window in windows for val in window}
    return [idx for idx in range(num_frames) if idx not in covered]
|
|
|
def uniform_looped(
    step: int = ...,
    num_steps: Optional[int] = None,
    num_frames: int = ...,
    context_size: Optional[int] = None,
    context_stride: int = 3,
    context_overlap: int = 4,
    closed_loop: bool = True,
):
    """Yield windows of frame indexes, wrapping indexes modulo ``num_frames``.

    Args:
        step: Current step; fed to ``ordered_halving`` to derive the offset
            applied to the window start positions.
        num_steps: Not used by this function.
        num_frames: Total number of frames.
        context_size: Number of indexes in each window.
        context_stride: Upper bound on how many power-of-two frame spacings
            (1, 2, 4, ...) are used; also capped by
            ``ceil(log2(num_frames / context_size)) + 1``.
        context_overlap: Subtracted from the window span to obtain the
            distance between consecutive window starts.
        closed_loop: When false, the stop bound for window starts is reduced
            by ``context_overlap``.

    Yields:
        Lists of frame indexes. When ``num_frames <= context_size`` a single
        window ``list(range(num_frames))`` is yielded.
    """
    if num_frames <= context_size:
        # Everything fits into a single window.
        yield list(range(num_frames))
        return

    stride_cap = int(np.ceil(np.log2(num_frames / context_size))) + 1
    context_stride = min(context_stride, stride_cap)

    for context_step in 1 << np.arange(context_stride):
        fraction = ordered_halving(step)
        pad = int(round(num_frames * fraction))

        first_start = int(fraction * context_step) + pad
        stop = num_frames + pad
        if not closed_loop:
            stop = stop + -context_overlap
        span = context_size * context_step
        advance = span - context_overlap

        for j in range(first_start, stop, advance):
            yield [e % num_frames for e in range(j, j + span, context_step)]
|
|
|
|
|
def uniform_standard(
    step: int = ...,
    num_steps: Optional[int] = None,
    num_frames: int = ...,
    context_size: Optional[int] = None,
    context_stride: int = 3,
    context_overlap: int = 4,
    closed_loop: bool = True,
):
    """Return windows of frame indexes with wrapped windows moved to the end.

    Windows are first generated with indexes wrapped modulo ``num_frames``.
    Any window that rolls over is then shifted so it ends on the last frame,
    an extra window may be inserted after it, and windows equal to an
    earlier one are dropped.

    Args:
        step: Current step; fed to ``ordered_halving`` to derive the offset
            applied to the window start positions.
        num_steps: Not used by this function.
        num_frames: Total number of frames.
        context_size: Number of indexes in each window.
        context_stride: Upper bound on how many power-of-two frame spacings
            are used; also capped by
            ``ceil(log2(num_frames / context_size)) + 1``.
        context_overlap: Subtracted from the window span to obtain the
            distance between consecutive window starts.
        closed_loop: When false, the stop bound for window starts is reduced
            by ``context_overlap``.

    Returns:
        List of windows (lists of frame indexes). When
        ``num_frames <= context_size`` the result is
        ``[list(range(num_frames))]``.
    """
    if num_frames <= context_size:
        return [list(range(num_frames))]

    stride_cap = int(np.ceil(np.log2(num_frames / context_size))) + 1
    context_stride = min(context_stride, stride_cap)

    # Phase 1: build the raw (possibly wrapping) windows.
    windows = []
    for context_step in 1 << np.arange(context_stride):
        fraction = ordered_halving(step)
        pad = int(round(num_frames * fraction))

        first_start = int(fraction * context_step) + pad
        stop = num_frames + pad
        if not closed_loop:
            stop = stop + -context_overlap
        span = context_size * context_step
        advance = span - context_overlap

        for j in range(first_start, stop, advance):
            windows.append([e % num_frames for e in range(j, j + span, context_step)])

    # Phase 2: straighten wrapped windows and note duplicates. The list can
    # grow while we walk it, so the length is re-read on every iteration.
    duplicate_positions = []
    position = 0
    while position < len(windows):
        current = windows[position]
        rolled, roll_at = does_window_roll_over(current, num_frames)
        if rolled:
            # Capture the value at the roll point before the window is shifted.
            roll_val = current[roll_at]
            shift_window_to_end(current, num_frames=num_frames)
            # If the following window (cyclically) lacks roll_val, insert a
            # plain contiguous window starting at roll_val right after this one.
            following = windows[(position + 1) % len(windows)]
            if roll_val not in following:
                windows.insert(position + 1, list(range(roll_val, roll_val + context_size)))
        # Mark this window for removal when an earlier window is identical.
        if current in windows[:position]:
            duplicate_positions.append(position)
        position += 1

    # Phase 3: delete from the back so earlier positions stay valid.
    for doomed in reversed(duplicate_positions):
        del windows[doomed]
    return windows
|
|
|
def static_standard(
    step: int = ...,
    num_steps: Optional[int] = None,
    num_frames: int = ...,
    context_size: Optional[int] = None,
    context_stride: int = 3,
    context_overlap: int = 4,
    closed_loop: bool = True,
):
    """Return a fixed set of contiguous windows covering all frames.

    Windows of ``context_size`` consecutive indexes start every
    ``context_size - context_overlap`` frames. The first window that would
    reach or pass the end is pulled back so that it ends exactly on the last
    frame, and no further windows are produced.

    Args:
        step: Not used by this function.
        num_steps: Not used by this function.
        num_frames: Total number of frames.
        context_size: Number of indexes in each window.
        context_stride: Not used by this function.
        context_overlap: Number of indexes shared by consecutive windows;
            must be smaller than ``context_size`` when more than one window
            is needed.
        closed_loop: Not used by this function.

    Returns:
        List of windows (lists of consecutive frame indexes). When
        ``num_frames <= context_size`` the result is
        ``[list(range(num_frames))]``.

    Raises:
        ValueError: If ``context_size - context_overlap`` is not positive.
            Previously a zero stride surfaced as an opaque ``range()`` error
            and a negative stride silently returned no windows at all.
    """
    if num_frames <= context_size:
        return [list(range(num_frames))]

    delta = context_size - context_overlap
    if delta <= 0:
        raise ValueError(
            f"context_overlap ({context_overlap}) must be smaller than "
            f"context_size ({context_size})"
        )

    windows = []
    for start_idx in range(0, num_frames, delta):
        if start_idx + context_size >= num_frames:
            # Final window: move the start back so the window keeps its full
            # length and ends on the last frame.
            final_start_idx = num_frames - context_size
            windows.append(list(range(final_start_idx, num_frames)))
            break
        windows.append(list(range(start_idx, start_idx + context_size)))
    return windows
|
|
|
def get_context_scheduler(name: str) -> Callable:
    """Look up a window scheduler function by its name.

    Args:
        name: One of ``"uniform_looped"``, ``"uniform_standard"`` or
            ``"static_standard"``.

    Returns:
        The scheduler function with that name.

    Raises:
        ValueError: If ``name`` is not one of the recognised names.
    """
    if name == "uniform_looped":
        return uniform_looped
    if name == "uniform_standard":
        return uniform_standard
    if name == "static_standard":
        return static_standard
    raise ValueError(f"Unknown context_overlap policy {name}")
|
|
|
|
|
def get_total_steps(
    scheduler,
    timesteps: List[int],
    num_steps: Optional[int] = None,
    num_frames: int = ...,
    context_size: Optional[int] = None,
    context_stride: int = 3,
    context_overlap: int = 4,
    closed_loop: bool = True,
):
    """Count the windows a scheduler produces across all timesteps.

    The scheduler is invoked once per entry of ``timesteps`` with the
    entry's position (0, 1, 2, ...) as its step argument, and the number of
    windows from every call is summed.

    Args:
        scheduler: Callable taking ``(step, num_steps, num_frames,
            context_size, context_stride, context_overlap, closed_loop)``
            positionally and returning an iterable of windows.
        timesteps: Only its length is used.
        num_steps: Forwarded to the scheduler.
        num_frames: Forwarded to the scheduler.
        context_size: Forwarded to the scheduler.
        context_stride: Forwarded to the scheduler.
        context_overlap: Forwarded to the scheduler.
        closed_loop: Forwarded to the scheduler. It was previously accepted
            but dropped, so the count always reflected the scheduler's
            default.

    Returns:
        Total number of windows over all timesteps.
    """
    total = 0
    for step in range(len(timesteps)):
        windows = scheduler(
            step,
            num_steps,
            num_frames,
            context_size,
            context_stride,
            context_overlap,
            closed_loop,
        )
        # Count without materializing; works for lists and generators alike.
        total += sum(1 for _ in windows)
    return total
|
|