import numpy as np
from typing import Callable, Optional, List


def ordered_halving(val):
    """Map a non-negative integer to a float in [0, 1) by bit reversal.

    The 64-bit binary representation of ``val`` is reversed and read back as a
    fraction of ``2**64``, so consecutive values spread out over [0, 1)
    (0 -> 0.0, 1 -> 0.5, 2 -> 0.25, 3 -> 0.75, ...). A negative ``val`` puts a
    ``'-'`` into the binary string, which makes ``int(..., 2)`` raise ValueError.
    """
    bin_str = f"{val:064b}"
    bin_flip = bin_str[::-1]
    as_int = int(bin_flip, 2)
    return as_int / (1 << 64)


def does_window_roll_over(window: list[int], num_frames: int) -> tuple[bool, int]:
    """Detect whether a window wraps past the last frame back to the start.

    Returns ``(True, i)`` where ``i`` is the first position whose value (taken
    modulo ``num_frames``) is smaller than the previous one, or ``(False, -1)``
    when the values never decrease.
    """
    prev_val = -1
    for i, val in enumerate(window):
        val = val % num_frames
        if val < prev_val:
            return True, i
        prev_val = val
    return False, -1


def shift_window_to_start(window: list[int], num_frames: int):
    """Shift ``window`` in place so that its first element becomes 0.

    Each value is replaced by its cyclic distance (mod ``num_frames``) from
    ``window[0]``; for a window spanning fewer than ``num_frames`` frames this
    also unwraps one that rolled over the end of the clip.
    """
    start_val = window[0]
    for i, val in enumerate(window):
        # 1) subtract start_val to make values relative to the window start
        # 2) add num_frames and take the modulus so wrapped values become positive
        window[i] = ((val - start_val) + num_frames) % num_frames


def shift_window_to_end(window: list[int], num_frames: int):
    """Shift ``window`` in place so that its last element becomes ``num_frames - 1``.

    The window is first normalised with ``shift_window_to_start``; every value
    is then moved forward by the same amount, preserving the spacing.
    """
    # 1) shift window to start
    shift_window_to_start(window, num_frames)
    end_delta = num_frames - window[-1] - 1
    # 2) add end_delta to each val to slide the window to the end (slice
    #    assignment keeps mutating the caller's list object)
    window[:] = [val + end_delta for val in window]


def get_missing_indexes(windows: list[list[int]], num_frames: int) -> list[int]:
    """Return, in ascending order, the indexes in ``range(num_frames)`` not in any window.

    Window values outside ``range(num_frames)`` are ignored.
    """
    covered = {val for w in windows for val in w}
    return [idx for idx in range(num_frames) if idx not in covered]


def _uniform_windows(step, num_frames, context_size, context_stride, context_overlap, closed_loop):
    """Yield the raw, possibly wrapping, uniform windows shared by the uniform schedulers.

    Windows are produced for strides 1, 2, 4, ... -- at most ``context_stride``
    of them, capped at ``ceil(log2(num_frames / context_size)) + 1``. Each window
    holds ``context_size`` frame indexes spaced by the stride and taken modulo
    ``num_frames``. Starts are offset by ``ordered_halving(step)`` so different
    steps see different window placements; consecutive starts at a given stride
    advance by ``context_size * stride - context_overlap``. With
    ``closed_loop=False`` the start scan ends ``context_overlap`` frames earlier.

    NOTE: if ``context_size * stride <= context_overlap`` the range step is not
    positive (zero raises ValueError, negative yields nothing for that stride).
    """
    context_stride = min(context_stride, int(np.ceil(np.log2(num_frames / context_size))) + 1)
    # Both values depend only on ``step``; compute them once, not once per stride.
    offset = ordered_halving(step)
    pad = int(round(num_frames * offset))
    stop = num_frames + pad + (0 if closed_loop else -context_overlap)
    for raw_step in 1 << np.arange(context_stride):
        stride = int(raw_step)
        span = context_size * stride
        for j in range(int(offset * stride) + pad, stop, span - context_overlap):
            yield [e % num_frames for e in range(j, j + span, stride)]


def uniform_looped(
    step: int = ...,
    num_steps: Optional[int] = None,
    num_frames: int = ...,
    context_size: Optional[int] = None,
    context_stride: int = 3,
    context_overlap: int = 4,
    closed_loop: bool = True,
):
    """Yield uniformly placed context windows whose frame indexes may wrap around.

    If the clip fits in one context (``num_frames <= context_size``) a single
    window covering every frame is yielded. Otherwise the windows described in
    ``_uniform_windows`` are yielded unchanged, so a window may continue from
    the last frame back to frame 0. ``num_steps`` is accepted for a uniform
    scheduler signature but is not used.
    """
    if num_frames <= context_size:
        yield list(range(num_frames))
        return
    yield from _uniform_windows(
        step, num_frames, context_size, context_stride, context_overlap, closed_loop
    )


#from AnimateDiff-Evolved by Kosinkadink (https://github.com/Kosinkadink/ComfyUI-AnimateDiff-Evolved)
def uniform_standard(
    step: int = ...,
    num_steps: Optional[int] = None,
    num_frames: int = ...,
    context_size: Optional[int] = None,
    context_stride: int = 3,
    context_overlap: int = 4,
    closed_loop: bool = True,
):
    """Return uniform context windows with wrapping windows moved back inside the clip.

    Starts from the same windows as ``uniform_looped``. Any window that rolls
    over the end of the clip is shifted (via ``shift_window_to_end``) so that it
    ends on the last frame. If the next window (cyclically) does not contain the
    value at which the roll happened, a contiguous window starting at that value
    is inserted right after it. Duplicate windows are removed, keeping the first
    occurrence. ``num_steps`` is accepted for a uniform scheduler signature but
    is not used.

    Returns:
        A list of windows, each a list of frame indexes.
    """
    if num_frames <= context_size:
        return [list(range(num_frames))]

    windows = list(
        _uniform_windows(step, num_frames, context_size, context_stride, context_overlap, closed_loop)
    )

    # now that windows are created, shift any windows that loop, and drop duplicate windows
    unique_windows = []
    seen = set()
    win_i = 0
    # ``windows`` can grow during the scan (inserted windows), so iterate by index.
    while win_i < len(windows):
        window = windows[win_i]
        is_roll, roll_idx = does_window_roll_over(window, num_frames)
        if is_roll:
            roll_val = window[roll_idx]  # roll_val might not be 0 for windows of higher strides
            shift_window_to_end(window, num_frames=num_frames)
            # check if next window (cyclical) is missing roll_val
            if roll_val not in windows[(win_i + 1) % len(windows)]:
                # need to insert new window here - just insert window starting at roll_val
                windows.insert(win_i + 1, list(range(roll_val, roll_val + context_size)))
        # keep only the first occurrence of each window (earlier windows never change again)
        key = tuple(window)
        if key not in seen:
            seen.add(key)
            unique_windows.append(window)
        win_i += 1
    return unique_windows


def static_standard(
    step: int = ...,
    num_steps: Optional[int] = None,
    num_frames: int = ...,
    context_size: Optional[int] = None,
    context_stride: int = 3,
    context_overlap: int = 4,
    closed_loop: bool = True,
):
    """Return a fixed set of contiguous, non-wrapping context windows.

    The result does not depend on ``step``; ``num_steps``, ``context_stride``
    and ``closed_loop`` are accepted for a uniform scheduler signature but are
    not used. Windows start every ``context_size - context_overlap`` frames; the
    last one is pulled back so it ends exactly on frame ``num_frames - 1`` while
    keeping the full ``context_size``. If ``num_frames <= context_size`` a single
    window covering every frame is returned.

    NOTE: ``context_overlap`` is expected to be smaller than ``context_size``;
    equal values make ``range`` raise ValueError (zero step).
    """
    if num_frames <= context_size:
        return [list(range(num_frames))]

    windows = []
    delta = context_size - context_overlap
    for start_idx in range(0, num_frames, delta):
        if start_idx + context_size >= num_frames:
            # past the end of frames: move the start back so the window keeps
            # the same context length and finishes on the final frame
            windows.append(list(range(num_frames - context_size, num_frames)))
            break
        windows.append(list(range(start_idx, start_idx + context_size)))
    return windows


def get_context_scheduler(name: str) -> Callable:
    """Return the context scheduler function registered under ``name``.

    Raises:
        ValueError: if ``name`` is not one of ``"uniform_looped"``,
            ``"uniform_standard"`` or ``"static_standard"``.
    """
    if name == "uniform_looped":
        return uniform_looped
    elif name == "uniform_standard":
        return uniform_standard
    elif name == "static_standard":
        return static_standard
    else:
        raise ValueError(f"Unknown context_overlap policy {name}")


def get_total_steps(
    scheduler,
    timesteps: List[int],
    num_steps: Optional[int] = None,
    num_frames: int = ...,
    context_size: Optional[int] = None,
    context_stride: int = 3,
    context_overlap: int = 4,
    closed_loop: bool = True,
):
    """Count the context windows ``scheduler`` produces across all timesteps.

    The scheduler is called once per entry of ``timesteps`` with the entry's
    *index* (not its value) as ``step``, and with every other argument --
    including ``closed_loop`` -- passed through positionally.
    """
    return sum(
        len(
            list(
                scheduler(
                    i,
                    num_steps,
                    num_frames,
                    context_size,
                    context_stride,
                    context_overlap,
                    closed_loop,
                )
            )
        )
        for i in range(len(timesteps))
    )