|
from typing import Union |
|
|
|
import torch |
|
import torchvision |
|
from PIL import Image, ImageFont, ImageDraw |
|
|
|
import numpy as np |
|
from torch import Tensor |
|
|
|
import comfy.samplers |
|
from comfy.model_base import BaseModel |
|
from comfy.model_patcher import ModelPatcher |
|
|
|
from .context_extras import ContextExtrasGroup |
|
from .utils_motion import get_sorted_list_via_attr |
|
|
|
|
|
class ContextFuseMethod:
    """Names of the fuse (weighting) methods; each name is a key of ``FUSE_MAPPING``."""

    # Sigma-independent methods.
    FLAT = "flat"
    PYRAMID = "pyramid"
    RELATIVE = "relative"
    RANDOM = "random"
    DELAYED_REVERSE_SAWTOOTH = "delayed reverse sawtooth"

    # Methods whose weight functions read the ``sigma`` keyword argument.
    GAUSS_SIGMA = "gauss-sigma"
    GAUSS_SIGMA_INV = "gauss-sigma inverse"
    PYRAMID_SIGMA = "pyramid-sigma"
    PYRAMID_SIGMA_INV = "pyramid-sigma inverse"

    # Every method except RELATIVE; the order of entries is preserved as-is.
    LIST = [
        PYRAMID,
        FLAT,
        DELAYED_REVERSE_SAWTOOTH,
        PYRAMID_SIGMA,
        PYRAMID_SIGMA_INV,
        GAUSS_SIGMA,
        GAUSS_SIGMA_INV,
        RANDOM,
    ]
    # Same as LIST, with RELATIVE placed right after the first entry.
    LIST_STATIC = [LIST[0], RELATIVE, *LIST[1:]]
|
|
|
|
|
class ContextType:
    """Namespace for context type identifiers (only one is defined here)."""

    UNIFORM_WINDOW = "uniform window"
|
|
|
|
|
class ContextOptions:
    """Settings describing how frames are split into context windows.

    Args:
        context_length: Window length read by the ``create_windows_*`` functions.
        context_stride: Upper bound on the number of power-of-two step sizes
            used by the uniform schedules.
        context_overlap: Number of frames subtracted from the window advance,
            i.e. how much neighbouring windows overlap.
        context_schedule: Key looked up in ``CONTEXT_MAPPING``.
        closed_loop: Read by the looped uniform schedule when computing where
            window starts stop.
        fuse_method: Key looked up in ``FUSE_MAPPING``.
        use_on_equal_length: Whether windows are still used when the video
            length equals ``context_length``.
        view_options: Optional nested options; a truthy value is cloned so
            this instance owns its own copy.
        start_percent: Converted to ``float``; later turned into ``start_t``
            by ``ContextOptionsGroup.initialize_timesteps``.
        guarantee_steps: Number of steps this context stays current before a
            later context may replace it.
    """

    def __init__(self, context_length: int=None, context_stride: int=None, context_overlap: int=None,
                 context_schedule: str=None, closed_loop: bool=False, fuse_method: str=ContextFuseMethod.FLAT,
                 use_on_equal_length: bool=False, view_options: 'ContextOptions'=None,
                 start_percent=0.0, guarantee_steps=1):
        # Window-shape settings.
        self.context_length = context_length
        self.context_stride = context_stride
        self.context_overlap = context_overlap
        self.context_schedule = context_schedule
        self.closed_loop = closed_loop
        self.fuse_method = fuse_method
        self.sync_context_to_pe = False
        self.use_on_equal_length = use_on_equal_length
        # Keep a private copy of nested options; falsy values are stored untouched.
        if view_options:
            view_options = view_options.clone()
        self.view_options = view_options

        # Scheduling settings; start_t is a large placeholder until it gets initialized.
        self.start_percent = float(start_percent)
        self.start_t = 999999999.9
        self.guarantee_steps = guarantee_steps

        # Current sampling step, exposed through the ``step`` property.
        self._step: int = 0

    @property
    def step(self):
        """Current step; assigning it also updates ``view_options`` when present."""
        return self._step

    @step.setter
    def step(self, value: int):
        self._step = value
        if self.view_options:
            self.view_options.step = value

    def clone(self):
        """Return a new ContextOptions with the same settings and ``start_t``.

        ``sync_context_to_pe`` and the current step are not carried over; the
        copy gets the constructor's values for those.
        """
        copied_fields = (
            "context_length", "context_stride", "context_overlap", "context_schedule",
            "closed_loop", "fuse_method", "use_on_equal_length", "view_options",
            "start_percent", "guarantee_steps",
        )
        duplicate = ContextOptions(**{name: getattr(self, name) for name in copied_fields})
        duplicate.start_t = self.start_t
        return duplicate
|
|
|
|
|
class ContextOptionsGroup:
    """Collection of ContextOptions kept sorted by ``start_percent``.

    Exactly one context is "current" at a time; the read-only properties at
    the bottom of the class forward to it, so a group can be passed wherever
    a single ContextOptions is expected by the window-creation functions.
    """

    def __init__(self):
        self.contexts: list[ContextOptions] = []
        self.extras = ContextExtrasGroup()
        self._current_context: ContextOptions = None
        self._current_used_steps: int = 0
        self._current_index: int = 0
        self._previous_t = -1
        self._step = 0

    def reset(self):
        """Return to the first context, zero all progress tracking and clean up extras."""
        self._current_context = None
        self._current_used_steps = 0
        self._current_index = 0
        self._previous_t = -1
        # Select the first context BEFORE resetting the step: the step setter only
        # propagates to the current context, so doing it the other way around
        # would leave that context (and its view_options) with a stale step.
        self._set_first_as_current()
        self.step = 0
        self.extras.cleanup()

    @property
    def step(self):
        """Current step; assigning it also updates the current context, if any."""
        return self._step

    @step.setter
    def step(self, value: int):
        self._step = value
        if self._current_context is not None:
            self._current_context.step = value

    @classmethod
    def default(cls):
        """Create a group holding a single ContextOptions built with default arguments."""
        def_context = ContextOptions()
        new_group = ContextOptionsGroup()
        new_group.add(def_context)
        return new_group

    def add(self, context: ContextOptions):
        """Append ``context``, re-sort by ``start_percent`` and make the first entry current."""
        self.contexts.append(context)
        self.contexts = get_sorted_list_via_attr(self.contexts, "start_percent")
        self._set_first_as_current()

    def add_to_start(self, context: ContextOptions):
        """Insert ``context`` at the front, re-sort by ``start_percent`` and make the first entry current."""
        self.contexts.insert(0, context)
        self.contexts = get_sorted_list_via_attr(self.contexts, "start_percent")
        self._set_first_as_current()

    def has_index(self, index: int) -> bool:
        """Return True when ``index`` is a valid position in ``contexts``."""
        return 0 <= index < len(self.contexts)

    def is_empty(self) -> bool:
        """Return True when the group holds no contexts."""
        return len(self.contexts) == 0

    def clone(self):
        """Return a new group with cloned extras that references the SAME context objects."""
        cloned = ContextOptionsGroup()
        cloned.extras = self.extras.clone()
        for context in self.contexts:
            cloned.contexts.append(context)
        cloned._set_first_as_current()
        return cloned

    def initialize_timesteps(self, model: BaseModel):
        """Convert every context's ``start_percent`` into ``start_t`` via the model's ``percent_to_sigma``."""
        for context in self.contexts:
            context.start_t = model.model_sampling.percent_to_sigma(context.start_percent)
        self.extras.initialize_timesteps(model)

    def prepare_current(self, t: Tensor):
        """Update the current context and the extras for timestep ``t``."""
        self.prepare_current_context(t)
        self.extras.prepare_current(t)

    def prepare_current_context(self, t: Tensor):
        """Advance to a later context when its ``start_t`` has been reached.

        Only the first element of ``t`` is used. Calling again with the same
        value as the previous call is a no-op, so a step is counted once.
        """
        curr_t: float = t[0]
        # Same timestep as the previous call: this step was already accounted for.
        if curr_t == self._previous_t:
            return
        # Only look for a replacement once the current context has been used
        # for its guaranteed number of steps.
        if self._current_used_steps >= self._current_context.guarantee_steps:
            if self.has_index(self._current_index+1):
                for i in range(self._current_index+1, len(self.contexts)):
                    eval_c = self.contexts[i]
                    # start_t comes from percent_to_sigma; presumably sigma decreases as
                    # sampling progresses, so start_t >= curr_t would mean this context's
                    # start has been reached - TODO confirm.
                    if eval_c.start_t >= curr_t:
                        self._current_index = i
                        self._current_context = eval_c
                        self._current_used_steps = 0
                        # A context with guaranteed steps must be used; stop searching.
                        if self._current_context.guarantee_steps > 0:
                            break
                    # Contexts are sorted by start_percent, so stop at the first
                    # one that does not qualify.
                    else:
                        break
        # Count this step against the (possibly new) current context.
        self._current_used_steps += 1
        self._previous_t = curr_t

    def _set_first_as_current(self):
        """Make the first context current; does nothing when the group is empty."""
        if len(self.contexts) > 0:
            self._current_context = self.contexts[0]

    # The properties below forward to the current context.
    @property
    def context_length(self):
        return self._current_context.context_length

    @property
    def context_overlap(self):
        return self._current_context.context_overlap

    @property
    def context_stride(self):
        return self._current_context.context_stride

    @property
    def context_schedule(self):
        return self._current_context.context_schedule

    @property
    def closed_loop(self):
        return self._current_context.closed_loop

    @property
    def fuse_method(self):
        return self._current_context.fuse_method

    @property
    def use_on_equal_length(self):
        return self._current_context.use_on_equal_length

    @property
    def view_options(self):
        return self._current_context.view_options
|
|
|
|
|
class ContextSchedules:
    """Names of the context schedules; the non-legacy names are keys of ``CONTEXT_MAPPING``."""

    UNIFORM_LOOPED = "looped_uniform"
    UNIFORM_STANDARD = "standard_uniform"
    STATIC_STANDARD = "standard_static"
    BATCHED = "batched"
    VIEW_AS_CONTEXT = "view_as_context"
    SVD_EXTENSION = "svd_extension"

    # Legacy schedule name(s), collected in a list.
    LEGACY_UNIFORM_LOOPED = "uniform"
    LEGACY_UNIFORM_SCHEDULE_LIST = [
        LEGACY_UNIFORM_LOOPED,
    ]
|
|
|
|
|
|
|
def create_windows_uniform_looped(num_frames: int, opts: Union[ContextOptionsGroup, ContextOptions]):
    """Build uniform context windows whose frame indexes wrap around modulo ``num_frames``.

    Returns a list of windows (lists of frame indexes). When ``num_frames`` is
    smaller than ``opts.context_length`` a single window covering every frame
    is returned.
    """
    if num_frames < opts.context_length:
        return [list(range(num_frames))]

    # Never use more power-of-two step sizes than the frame count can support.
    stride_limit = int(np.ceil(np.log2(num_frames / opts.context_length))) + 1
    context_stride = min(opts.context_stride, stride_limit)

    windows = []
    for context_step in 1 << np.arange(context_stride):
        # The step-dependent fraction shifts the window starts from step to step.
        fraction = ordered_halving(opts.step)
        pad = int(round(num_frames * fraction))
        first_start = int(fraction * context_step) + pad
        # Without closed_loop, starts stop context_overlap frames earlier.
        stop = num_frames + pad + (0 if opts.closed_loop else -opts.context_overlap)
        advance = opts.context_length * context_step - opts.context_overlap
        span = opts.context_length * context_step
        windows.extend(
            [e % num_frames for e in range(j, j + span, context_step)]
            for j in range(first_start, stop, advance)
        )
    return windows
|
|
|
|
|
def create_windows_uniform_standard(num_frames: int, opts: Union[ContextOptionsGroup, ContextOptions]):
    """Build uniform context windows that never wrap past the last frame.

    Windows are first generated as in the looped schedule. Any window that
    rolls over is then shifted so it ends on the last frame, a fill-in window
    is inserted when the following window does not contain the roll-over
    frame, and windows identical to an earlier one are dropped.
    """
    if num_frames <= opts.context_length:
        return [list(range(num_frames))]

    stride_limit = int(np.ceil(np.log2(num_frames / opts.context_length))) + 1
    context_stride = min(opts.context_stride, stride_limit)

    # Pass 1: plain uniform windows, wrap-around included.
    windows = []
    for context_step in 1 << np.arange(context_stride):
        fraction = ordered_halving(opts.step)
        pad = int(round(num_frames * fraction))
        first_start = int(fraction * context_step) + pad
        stop = num_frames + pad + (-opts.context_overlap)
        advance = opts.context_length * context_step - opts.context_overlap
        span = opts.context_length * context_step
        windows.extend(
            [e % num_frames for e in range(j, j + span, context_step)]
            for j in range(first_start, stop, advance)
        )

    # Pass 2: un-wrap rolled windows and remember duplicates. The list may grow
    # while it is walked, hence the explicit index.
    duplicate_idxs = []
    idx = 0
    while idx < len(windows):
        current = windows[idx]
        rolls, roll_idx = does_window_roll_over(current, num_frames)
        if rolls:
            # Capture the value at the roll-over point before the window is shifted in place.
            roll_val = current[roll_idx]
            shift_window_to_end(current, num_frames=num_frames)
            # If the next window (cyclically) lacks that frame, insert one starting there.
            follower = windows[(idx + 1) % len(windows)]
            if roll_val not in follower:
                windows.insert(idx + 1, list(range(roll_val, roll_val + opts.context_length)))
        if any(current == windows[earlier] for earlier in range(idx)):
            duplicate_idxs.append(idx)
        idx += 1

    # Pop from the back so the remaining recorded indexes stay valid.
    for doomed in reversed(duplicate_idxs):
        windows.pop(doomed)
    return windows
|
|
|
|
|
def create_windows_static_standard(num_frames: int, opts: Union[ContextOptionsGroup, ContextOptions]):
    """Build fixed-position windows that advance by ``context_length - context_overlap``.

    The window that would reach or pass the last frame is moved back so that
    it ends exactly on the last frame, and generation stops there. When
    ``num_frames`` does not exceed ``context_length`` a single window covering
    every frame is returned.
    """
    length = opts.context_length
    if num_frames <= length:
        return [list(range(num_frames))]

    advance = length - opts.context_overlap
    windows = []
    for begin in range(0, num_frames, advance):
        if begin + length >= num_frames:
            # Final window: keep the full length by ending on the last frame.
            windows.append(list(range(num_frames - length, num_frames)))
            break
        windows.append(list(range(begin, begin + length)))
    return windows
|
|
|
|
|
def create_windows_batched(num_frames: int, opts: Union[ContextOptionsGroup, ContextOptions]):
    """Split the frames into consecutive, non-overlapping windows of ``context_length``.

    The last window is shorter when ``num_frames`` is not a multiple of
    ``context_length``. When ``num_frames`` does not exceed ``context_length``
    a single window covering every frame is returned.
    """
    if num_frames <= opts.context_length:
        return [list(range(num_frames))]
    return [
        list(range(begin, min(begin + opts.context_length, num_frames)))
        for begin in range(0, num_frames, opts.context_length)
    ]
|
|
|
|
|
def create_windows_default(num_frames: int, opts: Union[ContextOptionsGroup, ContextOptions]):
    """Return one window containing every frame index; ``opts`` is not consulted."""
    every_frame = [*range(num_frames)]
    return [every_frame]
|
|
|
|
|
def get_context_windows(num_frames: int, opts: Union[ContextOptionsGroup, ContextOptions]):
    """Create the context windows for ``num_frames`` using the schedule named by ``opts``.

    Raises:
        ValueError: if ``opts.context_schedule`` is not a key of ``CONTEXT_MAPPING``.
    """
    schedule = opts.context_schedule
    context_func = CONTEXT_MAPPING.get(schedule)
    if context_func is None:
        raise ValueError(f"Unknown context_schedule '{opts.context_schedule}'.")
    return context_func(num_frames, opts)
|
|
|
|
|
# Schedule name -> function building the windows for that schedule.
CONTEXT_MAPPING = {
    # Uniform schedules.
    ContextSchedules.UNIFORM_LOOPED: create_windows_uniform_looped,
    ContextSchedules.UNIFORM_STANDARD: create_windows_uniform_standard,
    # Fixed-position schedule.
    ContextSchedules.STATIC_STANDARD: create_windows_static_standard,
    # BATCHED and SVD_EXTENSION share the same window layout.
    ContextSchedules.BATCHED: create_windows_batched,
    ContextSchedules.SVD_EXTENSION: create_windows_batched,
    # A single window spanning every frame.
    ContextSchedules.VIEW_AS_CONTEXT: create_windows_default,
}
|
|
|
|
|
def get_context_weights(num_frames: int, fuse_method: str, sigma: Tensor = None):
    """Return per-frame fuse weights for a window of ``num_frames`` frames.

    ``sigma`` is forwarded as a keyword argument to the selected weight function.

    Raises:
        ValueError: if ``fuse_method`` is not a key of ``FUSE_MAPPING``.
    """
    weights_func = FUSE_MAPPING.get(fuse_method)
    if weights_func is None:
        raise ValueError(f"Unknown fuse_method '{fuse_method}'.")
    return weights_func(num_frames, sigma=sigma)
|
|
|
|
|
def create_weights_flat(length: int, **kwargs) -> list[float]:
    """Return ``length`` equal weights of 1.0; extra keyword arguments are ignored."""
    return [1.0 for _ in range(length)]
|
|
|
def create_weights_pyramid(length: int, **kwargs) -> list[float]:
    """Return weights that rise from 1 at the edges to a peak in the middle.

    An even ``length`` has two equal peak entries, an odd ``length`` a single
    one. Extra keyword arguments are ignored.
    """
    if length % 2 == 0:
        peak = length // 2
        # Even: the peak value appears twice, e.g. 1, 2, 2, 1.
        return [*range(1, peak + 1), *range(peak, 0, -1)]
    peak = (length + 1) // 2
    # Odd: a single peak in the centre, e.g. 1, 2, 3, 2, 1.
    return [*range(1, peak), peak, *range(peak - 1, 0, -1)]
|
|
|
def create_weights_random(length: int, **kwargs) -> list[float]:
    """Return ``length`` random weights drawn from NumPy's global random state.

    Each weight is a uniform sample from [0, 1) scaled by the pyramid peak for
    this length, plus 0.001. Extra keyword arguments are ignored.
    """
    peak = length // 2 if length % 2 == 0 else (length + 1) // 2
    samples = np.random.random(length)
    return list(samples * peak + 0.001)
|
|
|
def create_weights_gauss_sigma(length: int, **kwargs) -> list[float]:
    """Return Gaussian-shaped weights whose width grows with the mean of ``sigma``.

    Requires a ``sigma`` keyword argument exposing ``.mean().cpu()``. Its mean
    is capped at 4.0 and mapped linearly onto a Gaussian width between 1.0 and
    9.0. The weights are scaled so that their Euclidean norm equals the
    pyramid peak for this length.
    """
    capped = min(4.0, kwargs["sigma"].mean().cpu())
    spread = 1.0 + 8.0*(capped / 4.0)
    half_span = (length - 1) / 2.
    offsets = np.linspace(-half_span, half_span, length)
    weights = np.exp(-0.5 * np.square(offsets) / np.square(spread))
    peak = length // 2 if length % 2 == 0 else (length + 1) // 2
    weights *= peak / np.linalg.norm(weights)
    return list(weights)
|
|
|
def create_weights_gauss_sigma_inv(length: int, **kwargs) -> list[float]:
    """Return Gaussian-shaped weights whose width shrinks as the mean of ``sigma`` grows.

    Requires a ``sigma`` keyword argument exposing ``.mean().cpu()``. Its mean
    is capped at 4.0 and mapped linearly onto a Gaussian width between 9.0
    (mean 0) and 1.0 (mean >= 4). The weights are scaled so that their
    Euclidean norm equals the pyramid peak for this length.
    """
    capped = min(4.0, kwargs["sigma"].mean().cpu())
    spread = 1.0 + 8.0*(1.0 - capped / 4.0)
    half_span = (length - 1) / 2.
    offsets = np.linspace(-half_span, half_span, length)
    weights = np.exp(-0.5 * np.square(offsets) / np.square(spread))
    peak = length // 2 if length % 2 == 0 else (length + 1) // 2
    weights *= peak / np.linalg.norm(weights)
    return list(weights)
|
|
|
def create_weights_pyramid_sigma_inv(length: int, **kwargs) -> list[float]:
    """Blend a pyramid with a centre-only spike; high ``sigma`` favours the spike.

    Requires a ``sigma`` keyword argument exposing ``.mean().cpu()``. Its mean
    is capped at 4.0 and divided by 4.0 to obtain a blend factor ``s``. The
    result is ``s * spike + (1 - s) * pyramid`` clipped to
    ``[0.001, max_weight]``, where the spike is ``max_weight`` at the centre
    entry (two entries for even lengths) and ``-max_weight`` elsewhere.
    """
    sigma = min(4.0, kwargs["sigma"].mean().cpu()) / 4.0

    if length % 2 == 0:
        max_weight = length // 2
        pyramid = np.array(list(range(1, max_weight + 1, 1)) + list(range(max_weight, 0, -1)))
        spike = np.array([-max_weight]*(max_weight-1) + [max_weight, max_weight] + [-max_weight]*(max_weight-1))
    else:
        max_weight = (length + 1) // 2
        # Both sequences must be arrays with exactly `length` entries; a plain
        # list or an extra leading entry makes the blend below raise.
        pyramid = np.array(list(range(1, max_weight, 1)) + [max_weight] + list(range(max_weight - 1, 0, -1)))
        spike = np.array([-max_weight]*(max_weight-1) + [max_weight] + [-max_weight]*(max_weight-1))
    weight_sequence = (sigma * spike + (1.0-sigma) * pyramid).clip(0.001, max_weight)

    return list(weight_sequence)
|
|
|
def create_weights_pyramid_sigma(length: int, **kwargs) -> list[float]:
    """Blend a pyramid with a centre-only spike; high ``sigma`` favours the pyramid.

    Requires a ``sigma`` keyword argument exposing ``.mean().cpu()``. Its mean
    is capped at 4.0 and divided by 4.0 to obtain a blend factor ``s``. The
    result is ``s * pyramid + (1 - s) * spike`` clipped to
    ``[0.001, max_weight]``, where the spike is ``max_weight`` at the centre
    entry (two entries for even lengths) and ``-max_weight`` elsewhere.
    """
    sigma = min(4.0, kwargs["sigma"].mean().cpu()) / 4.0

    if length % 2 == 0:
        max_weight = length // 2
        pyramid = np.array(list(range(1, max_weight + 1, 1)) + list(range(max_weight, 0, -1)))
        spike = np.array([-max_weight]*(max_weight-1) + [max_weight, max_weight] + [-max_weight]*(max_weight-1))
    else:
        max_weight = (length + 1) // 2
        # Both sequences must be arrays with exactly `length` entries; a plain
        # list or an extra leading entry makes the blend below raise.
        pyramid = np.array(list(range(1, max_weight, 1)) + [max_weight] + list(range(max_weight - 1, 0, -1)))
        spike = np.array([-max_weight]*(max_weight-1) + [max_weight] + [-max_weight]*(max_weight-1))
    weight_sequence = (sigma * pyramid + (1.0-sigma) * spike).clip(0.001, max_weight)

    return list(weight_sequence)
|
|
|
def create_weights_delayed_reverse_sawtooth(length: int, **kwargs) -> list[float]:
    """Return weights that stay near zero in the first part, then jump to the peak and ramp down to 1.

    The leading entries are 0.01, followed by the peak and a descending ramp
    ending at 1. The result always has exactly ``length`` entries. Extra
    keyword arguments are ignored.
    """
    if length % 2 == 0:
        max_weight = length // 2
        weight_sequence = [0.01]*(max_weight-1) + [max_weight] + list(range(max_weight, 0, -1))
    else:
        max_weight = (length + 1) // 2
        # (max_weight - 1) low entries + peak + (max_weight - 1) ramp entries == length;
        # using max_weight low entries would yield one weight too many.
        weight_sequence = [0.01]*(max_weight-1) + [max_weight] + list(range(max_weight - 1, 0, -1))

    return weight_sequence
|
|
|
|
|
# Fuse method name -> function producing the per-frame weights for a window.
FUSE_MAPPING = {
    # Sigma-independent weightings; RELATIVE reuses the pyramid weights.
    ContextFuseMethod.FLAT: create_weights_flat,
    ContextFuseMethod.PYRAMID: create_weights_pyramid,
    ContextFuseMethod.RELATIVE: create_weights_pyramid,
    # Weightings that read the ``sigma`` keyword argument.
    ContextFuseMethod.GAUSS_SIGMA: create_weights_gauss_sigma,
    ContextFuseMethod.GAUSS_SIGMA_INV: create_weights_gauss_sigma_inv,
    ContextFuseMethod.RANDOM: create_weights_random,
    ContextFuseMethod.DELAYED_REVERSE_SAWTOOTH: create_weights_delayed_reverse_sawtooth,
    ContextFuseMethod.PYRAMID_SIGMA: create_weights_pyramid_sigma,
    ContextFuseMethod.PYRAMID_SIGMA_INV: create_weights_pyramid_sigma_inv,
}
|
|
|
|
|
|
|
def ordered_halving(val):
    """Map an integer to a fraction by reversing its zero-padded 64-bit binary form.

    The reversed bit string is read back as an integer and divided by 2**64,
    so 0 -> 0.0, 1 -> 0.5, 2 -> 0.25, 3 -> 0.75, and so on.
    """
    reversed_bits = format(val, "064b")[::-1]
    return int(reversed_bits, 2) / (1 << 64)
|
|
|
|
|
def get_missing_indexes(windows: list[list[int]], num_frames: int) -> list[int]:
    """Return, in ascending order, the frame indexes in ``range(num_frames)`` that no window contains.

    Indexes inside ``windows`` that fall outside ``range(num_frames)`` are ignored.
    """
    # One set lookup per frame instead of a linear list.remove() per window entry.
    covered = {val for window in windows for val in window}
    return [idx for idx in range(num_frames) if idx not in covered]
|
|
|
|
|
def does_window_roll_over(window: list[int], num_frames: int) -> tuple[bool, int]:
    """Report whether the window's indexes (taken modulo ``num_frames``) ever decrease.

    Returns:
        ``(True, i)`` where ``i`` is the position of the first entry smaller
        than its predecessor, or ``(False, -1)`` when there is none.
    """
    previous = -1
    for position, current in enumerate(val % num_frames for val in window):
        if current < previous:
            return True, position
        previous = current
    return False, -1
|
|
|
|
|
def shift_window_to_start(window: list[int], num_frames: int):
    """Shift ``window`` in place so its first entry becomes 0, wrapping modulo ``num_frames``."""
    first = window[0]
    # Adding num_frames before the modulo keeps the intermediate value
    # non-negative for entries that are smaller than the first one.
    window[:] = [((val - first) + num_frames) % num_frames for val in window]
|
|
|
|
|
def shift_window_to_end(window: list[int], num_frames: int):
    """Shift ``window`` in place so its last entry becomes ``num_frames - 1``.

    The window is first normalised to start at 0 (via shift_window_to_start)
    and then moved right by the gap between its last entry and the final
    frame index.
    """
    shift_window_to_start(window, num_frames)
    gap = num_frames - window[-1] - 1
    window[:] = [val + gap for val in window]
|
|
|
|
|
|
|
|
|
|
|
class Colors:
    """RGB colour tuples used by the visualisation code."""

    # Greyscale extremes.
    BLACK = (0, 0, 0)
    WHITE = (255, 255, 255)
    # Primary colours.
    RED = (255, 0, 0)
    GREEN = (0, 255, 0)
    BLUE = (0, 0, 255)
    # Secondary colours.
    YELLOW = (255, 255, 0)
    MAGENTA = (255, 0, 255)
    CYAN = (0, 255, 255)
|
|
|
|
|
class BorderWidth:
    """Outline widths passed as ``width`` to the rectangle drawing calls."""

    INDEXES = 2  # cells of the frame-index row
    CONTEXT = 4  # cells of the context / view rows
|
|
|
|
|
class VisualizeSettings:
    """Geometry, fonts and colours shared by the context visualisation helpers.

    Args:
        img_width: Width of the generated image.
        video_length: Number of frames; the image width is divided by it to
            obtain the side length of one grid cell.
    """

    def __init__(self, img_width: int, video_length: int):
        cell = img_width // video_length
        font_size = int(cell * 0.5)

        # Geometry: the image is five grid cells tall.
        self.video_length = video_length
        self.img_width = img_width
        self.grid = cell
        self.img_height = cell * 5

        # Converter from PIL images to tensors.
        self.pil_to_tensor = torchvision.transforms.Compose([torchvision.transforms.PILToTensor()])

        # Fonts: the title font is 1.2x the regular size.
        self.font_size = font_size
        self.font = ImageFont.load_default(size=font_size)
        self.title_font = ImageFont.load_default(size=int(font_size * 1.2))

        # Colour scheme.
        self.background_color = Colors.BLACK
        self.grid_outline_color = Colors.WHITE
        self.start_idx_fill_color = Colors.MAGENTA
        self.subidx_end_color = Colors.YELLOW
        self.context_color = Colors.GREEN
        self.view_color = Colors.RED
|
|
|
|
|
class GridDisplay:
    """Bundle of a drawing surface, its settings and the origin that drawing is offset by.

    Args:
        draw: Drawing object the helpers call ``rectangle`` and ``text`` on.
        vs: Settings shared by the drawing helpers.
        home_x: X offset added to every drawn coordinate.
        home_y: Y offset added to every drawn coordinate.
    """

    def __init__(self, draw: ImageDraw.ImageDraw, vs: VisualizeSettings, home_x: int=0, home_y: int=0):
        self.home_x, self.home_y = home_x, home_y
        self.draw, self.vs = draw, vs
|
|
|
|
|
def get_text_xy(input: str, font: ImageFont, x: int, y: int, centered=True):
    """Return the position at which text should be drawn.

    Currently this returns ``(x, y)`` unchanged; ``input``, ``font`` and
    ``centered`` are accepted but not used.
    """
    position = (x, y)
    return position
|
|
|
|
|
def draw_text(text: str, font: ImageFont, gd: GridDisplay, x: int, y: int, color=Colors.WHITE, centered=True):
    """Draw ``text`` at ``(x, y)`` relative to the display's home position."""
    text_x, text_y = get_text_xy(text, font, x, y, centered=centered)
    anchor = (gd.home_x+text_x, gd.home_y+text_y)
    gd.draw.text(xy=anchor, text=text, fill=color, font=font)
|
|
|
|
|
def draw_first_grid_row(total_length: int, gd: GridDisplay, start_idx=-1):
    """Draw one outlined, numbered cell per frame along the display's home row.

    The cell whose index equals ``start_idx`` is filled with the start-index
    colour; all other cells are left unfilled.
    """
    vs = gd.vs
    cell = vs.grid
    top = gd.home_y
    for idx in range(total_length):
        left = gd.home_x + (cell * idx)
        fill = vs.start_idx_fill_color if idx == start_idx else None
        gd.draw.rectangle(
            xy=(left, top, left + cell, top + cell),
            fill=fill, outline=vs.grid_outline_color, width=BorderWidth.INDEXES,
        )
        # Label the cell with its frame index.
        draw_text(text=str(idx), font=vs.font, gd=gd, x=cell * idx, y=0)
|
|
|
|
|
def draw_subidxs(window: list[int], gd: GridDisplay, y_grid_offset: int, color: tuple):
    """Draw one filled cell per frame index in ``window`` on the row ``y_grid_offset`` cells below home.

    The first and last entries of the window are filled with the end-marker
    colour; every other entry is filled with ``color``. All cells are
    outlined with ``color``.
    """
    vs = gd.vs
    cell = vs.grid
    top = gd.home_y + cell * y_grid_offset
    last_pos = len(window) - 1
    for pos, frame_idx in enumerate(window):
        left = gd.home_x + (cell * frame_idx)
        # Highlight both ends of the window.
        fill_color = vs.subidx_end_color if pos in (0, last_pos) else color
        gd.draw.rectangle(
            xy=(left, top, left + cell, top + cell),
            fill=fill_color, outline=color, width=BorderWidth.CONTEXT,
        )
|
|
|
|
|
def draw_context(window: list[int], gd: GridDisplay):
    """Draw ``window`` on the first row below home using the context colour."""
    context_color = gd.vs.context_color
    draw_subidxs(window=window, gd=gd, y_grid_offset=1, color=context_color)
|
|
|
|
|
def draw_view(window: list[int], gd: GridDisplay):
    """Draw ``window`` on the second row below home using the view colour."""
    view_color = gd.vs.view_color
    draw_subidxs(window=window, gd=gd, y_grid_offset=2, color=view_color)
|
|
|
|
|
def generate_context_visualization(model: ModelPatcher, context_opts: ContextOptionsGroup=None, sampler_name: str=None, scheduler: str=None,
                                   width=1440, height=200, video_length=32,
                                   steps=None, start_step=None, end_step=None, sigmas=None, force_full_denoise=False, denoise=None):
    """Render one image per (step, context window, view window) showing which frames are covered.

    Args:
        model: Patched model; supplies the sampler configuration and, when
            ``context_opts`` is None, the attached "ADE_params" context options.
        context_opts: Options to visualise. A clone is used, so the caller's
            group object is not reset by this function.
        sampler_name, scheduler, steps, denoise: Forwarded to
            ``comfy.samplers.KSampler`` when ``sigmas`` is not supplied.
        width: Image width.
        height: Accepted but not used; the image height is derived from the grid size.
        video_length: Number of frames to lay out.
        start_step, end_step, force_full_denoise: Used to trim the computed
            sigmas when ``sigmas`` is not supplied; ``start_step`` also
            offsets the displayed step number.
        sigmas: Pre-computed sigmas; when given, no sampler is created.

    Returns:
        The stacked frames with dimension 1 moved to the end, as float32.
    """
    # NOTE(review): block nesting in this function was reconstructed from the
    # control flow because the source had lost its indentation - confirm.
    if context_opts is None:
        context_opts = ContextOptionsGroup.default()
        attached = model.get_attachment("ADE_params")
        if attached is not None:
            context_opts = attached.context_options
    context_opts = context_opts.clone()
    vs = VisualizeSettings(width, video_length)
    all_imgs = []

    if sigmas is None:
        sampler = comfy.samplers.KSampler(
            model=model, steps=steps, device="cpu", sampler=sampler_name, scheduler=scheduler,
            denoise=denoise, model_options=model.model_options,
        )
        sigmas = sampler.sigmas
        if end_step is not None and end_step < (len(sigmas) - 1):
            sigmas = sigmas[:end_step + 1]
            if force_full_denoise:
                sigmas[-1] = 0
        if start_step is not None and start_step < (len(sigmas) - 1):
            sigmas = sigmas[start_step:]
    # Drop the final sigma so one step is rendered per remaining entry.
    sigmas = sigmas[:-1]

    context_opts.reset()
    context_opts.initialize_timesteps(model.model)

    # Defaults used only for the displayed "Step a/b" text.
    if start_step is None:
        start_step = 0
    if steps is None:
        steps = len(sigmas)

    for step_offset, sigma in enumerate(sigmas):
        # Make context_opts reflect the current sigma and step.
        context_opts.prepare_current([sigma])
        context_opts.step = start_step + step_offset

        # Windows are only used when a context length is set, the video is at
        # least that long, and equal length is explicitly allowed.
        context_active = (
            context_opts.context_length is not None
            and not video_length < context_opts.context_length
            and not (video_length == context_opts.context_length and not context_opts.use_on_equal_length)
        )
        if context_active:
            context_windows = get_context_windows(num_frames=video_length, opts=context_opts)
        else:
            context_windows = [list(range(video_length))]

        start_idx = -1
        for j, window in enumerate(context_windows):
            view_windows = []
            total_repeats = 1
            view_options = context_opts.view_options
            if view_options is not None:
                # NOTE(review): the equal-length check compares video_length, not
                # len(window), against the view's context_length - confirm intent.
                view_inactive = (
                    len(window) < view_options.context_length
                    or (video_length == view_options.context_length and not view_options.use_on_equal_length)
                )
                if not view_inactive:
                    view_windows = get_context_windows(num_frames=len(window), opts=view_options)
                    total_repeats = len(view_windows)

            # One frame per view window, or a single frame when views are not active.
            for repeat_count in range(total_repeats):
                frame: Image = Image.new(mode="RGB", size=(vs.img_width, vs.img_height), color=vs.background_color)
                draw = ImageDraw.Draw(frame)
                gd = GridDisplay(draw=draw, vs=vs, home_x=0, home_y=vs.grid)

                # View row: view indexes are positions inside the context window.
                if len(view_windows) > 0:
                    converted_view = [window[x] for x in view_windows[repeat_count]]
                    draw_view(window=converted_view, gd=gd)

                # Title row, drawn at the image origin rather than the grid home.
                title_str = f"{context_opts.context_schedule} - Step {context_opts.step+1}/{steps} (Context {j+1}/{len(context_windows)})"
                if len(view_windows) > 0:
                    title_str = f"{title_str} (View {repeat_count+1}/{len(view_windows)})"
                draw_text(text=title_str, font=vs.title_font, gd=gd, x=0-gd.home_x, y=0-gd.home_y, centered=False)

                # Index row: the first frame of the step's first window stays highlighted.
                if j == 0:
                    start_idx = window[0]
                draw_first_grid_row(total_length=video_length, gd=gd, start_idx=start_idx)

                # Context row.
                draw_context(window=window, gd=gd)

                img: Tensor = vs.pil_to_tensor(frame)
                all_imgs.append(img)

    images = torch.stack(all_imgs)
    images = images.movedim(1, -1).to(torch.float32)
    return images
|
|