|
from typing import Union |
|
import math |
|
import torch |
|
from torch import Tensor |
|
|
|
from comfy.model_base import BaseModel |
|
|
|
from .utils_motion import (prepare_mask_batch, extend_to_batch_size, get_combined_multival, resize_multival, |
|
get_sorted_list_via_attr) |
|
|
|
|
|
# Version tag that ContextRef copies onto each instance as its `version` attribute.
CONTEXTREF_VERSION = 1
|
|
|
|
|
class ContextExtra:
    """Base class for a context extra that is only active inside a scheduled window.

    The window is supplied as start/end percents and is translated into
    threshold values (``start_t`` / ``end_t``) through the model's
    ``model_sampling.percent_to_sigma``.
    """

    def __init__(self, start_percent: float, end_percent: float):
        """Store the percent window.

        The thresholds keep placeholder values until ``initialize_timesteps``
        is called.
        """
        self.start_percent = float(start_percent)
        self.end_percent = float(end_percent)
        # Placeholder thresholds/position used before any model is known.
        self.start_t = 999999999.9
        self.end_t = 0.0
        self.curr_t = 999999999.9

    def initialize_timesteps(self, model: BaseModel):
        """Convert the percent window into thresholds using ``model.model_sampling``."""
        to_sigma = model.model_sampling.percent_to_sigma
        self.start_t = to_sigma(self.start_percent)
        self.end_t = to_sigma(self.end_percent)

    def prepare_current(self, t: Tensor):
        """Record the first element of ``t`` as the current position."""
        self.curr_t = t[0]

    def should_run(self):
        """Return False when ``curr_t`` is above ``start_t`` or below ``end_t``, else True."""
        outside_window = self.curr_t > self.start_t or self.curr_t < self.end_t
        return not outside_window

    def cleanup(self):
        """Hook for subclasses to release per-run state; the base class holds none."""
|
|
|
|
|
|
|
|
|
class ContextRefTune:
    """Bundle of six float tuning values: three ``attn_*`` and three ``adain_*``."""

    # Attribute names, in the order they appear in __init__ and in create_dict().
    _FIELDS = (
        "attn_style_fidelity",
        "attn_ref_weight",
        "attn_strength",
        "adain_style_fidelity",
        "adain_ref_weight",
        "adain_strength",
    )

    def __init__(self,
                 attn_style_fidelity=0.0, attn_ref_weight=0.0, attn_strength=0.0,
                 adain_style_fidelity=0.0, adain_ref_weight=0.0, adain_strength=0.0):
        """Coerce every supplied value to ``float`` and store it under its own name."""
        supplied = (
            attn_style_fidelity, attn_ref_weight, attn_strength,
            adain_style_fidelity, adain_ref_weight, adain_strength,
        )
        for field_name, raw_value in zip(self._FIELDS, supplied):
            setattr(self, field_name, float(raw_value))

    def create_dict(self):
        """Return a new dict mapping each tuning attribute name to its current value."""
        return {field_name: getattr(self, field_name) for field_name in self._FIELDS}
|
|
|
|
|
class ContextRefMode:
    """Describes which mode a ContextRef uses: ``first``, ``sliding`` or ``indexes``."""

    FIRST = "first"
    SLIDING = "sliding"
    INDEXES = "indexes"
    # All recognised mode names.
    _LIST = [FIRST, SLIDING, INDEXES]

    def __init__(self, mode: str, sliding_width=2, indexes: Union[set[int], None] = None):
        """Create a mode description.

        Args:
            mode: One of the mode-name constants defined on this class.
            sliding_width: Width value stored for the sliding mode.
            indexes: Set of indexes stored for the indexes mode. When omitted,
                a fresh ``{0}`` is created for this instance.
        """
        self.mode = mode
        self.sliding_width = sliding_width
        # A new set per instance: a shared default object would let a mutation on
        # one instance leak into every other instance created without `indexes`.
        self.indexes = {0} if indexes is None else indexes
        self.single_trigger = True

    @classmethod
    def init_first(cls):
        """Build an instance in ``first`` mode with default width/indexes."""
        return cls(cls.FIRST)

    @classmethod
    def init_sliding(cls, sliding_width: int):
        """Build an instance in ``sliding`` mode with the given width."""
        return cls(cls.SLIDING, sliding_width=sliding_width)

    @classmethod
    def init_indexes(cls, indexes: set[int]):
        """Build an instance in ``indexes`` mode using the given index set."""
        return cls(cls.INDEXES, indexes=indexes)
|
|
|
|
|
class ContextRefKeyframe:
    """One scheduled keyframe: a multiplier plus optional multival/tune/mode overrides.

    Each override is stored twice: ``orig_*`` keeps the value passed by the
    caller, while the plain attribute starts out equal to it and may later be
    overwritten by whoever owns the keyframe.
    """

    def __init__(self, mult=1.0, mult_multival: Union[float, Tensor]=None, tune_replace: ContextRefTune=None, mode_replace: ContextRefMode=None,
                 start_percent=0.0, guarantee_steps=1, inherit_missing=True):
        # Scheduling values.
        self.start_percent = float(start_percent)
        self.guarantee_steps = guarantee_steps
        self.inherit_missing = inherit_missing
        # Multiplier and overrides (original + working copy of each).
        self.mult = mult
        self.orig_mult_multival = self.mult_multival = mult_multival
        self.orig_tune_replace = self.tune_replace = tune_replace
        self.orig_mode_replace = self.mode_replace = mode_replace

    def clone(self):
        """Return a new keyframe rebuilt from the ``orig_*`` values and scheduling fields."""
        ctor_kwargs = dict(
            mult=self.mult,
            mult_multival=self.orig_mult_multival,
            tune_replace=self.orig_tune_replace,
            mode_replace=self.orig_mode_replace,
            start_percent=self.start_percent,
            guarantee_steps=self.guarantee_steps,
            inherit_missing=self.inherit_missing,
        )
        return ContextRefKeyframe(**ctor_kwargs)
|
|
|
|
|
class ContextRefKeyframeGroup:
    """Collection of ContextRefKeyframe objects kept sorted by ``start_percent``.

    Besides storing keyframes, the group resolves inherited values: a keyframe
    whose ``orig_mult_multival`` / ``orig_tune_replace`` / ``orig_mode_replace``
    is None receives the most recent non-None value from the keyframes before
    it, unless its ``inherit_missing`` flag is False.
    """

    def __init__(self):
        self.keyframes: list[ContextRefKeyframe] = []
        self._current_keyframe: ContextRefKeyframe = None
        self._current_used_steps: int = 0
        self._current_index: int = 0
        self._previous_t = -1

    def reset(self):
        """Restore all tracking state to its initial values and re-select the first keyframe."""
        self._current_keyframe = None
        self._current_used_steps = 0
        self._current_index = 0
        # Also clear the remembered step so no state from a previous run survives.
        self._previous_t = -1
        self._set_first_as_current()

    def add(self, keyframe: ContextRefKeyframe):
        """Add ``keyframe``, re-sort by ``start_percent`` and recompute inherited values."""
        self.keyframes.append(keyframe)
        self.keyframes = get_sorted_list_via_attr(self.keyframes, "start_percent")
        self._set_first_as_current()
        self._prepare_all_keyframe_vals()

    def _set_first_as_current(self):
        """Point ``_current_keyframe`` at the first keyframe, or None when empty."""
        self._current_keyframe = self.keyframes[0] if self.keyframes else None

    def _prepare_all_keyframe_vals(self):
        """Fill each keyframe's working values, inheriting from earlier keyframes when unset."""
        if self.is_empty():
            return
        # Most recent non-None value seen so far for each inheritable field.
        multival = None
        tune = None
        mode = None
        for kf in self.keyframes:
            # A keyframe that opts out of inheritance starts from a clean slate;
            # keyframes after it can then only inherit from it onwards.
            if not kf.inherit_missing:
                multival = None
                tune = None
                mode = None
            # Explicit values win; otherwise fall back to what was inherited.
            kf.mult_multival = multival if kf.orig_mult_multival is None else kf.orig_mult_multival
            kf.tune_replace = tune if kf.orig_tune_replace is None else kf.orig_tune_replace
            kf.mode_replace = mode if kf.orig_mode_replace is None else kf.orig_mode_replace
            # Remember resolved values for the keyframes that follow.
            if kf.mult_multival is not None:
                multival = kf.mult_multival
            if kf.tune_replace is not None:
                tune = kf.tune_replace
            if kf.mode_replace is not None:
                mode = kf.mode_replace

    def has_index(self, index: int) -> bool:
        """Return True when ``index`` is a valid non-negative position in ``keyframes``."""
        return 0 <= index < len(self.keyframes)

    def is_empty(self) -> bool:
        """Return True when the group holds no keyframes."""
        return len(self.keyframes) == 0

    def clone(self):
        """Return a new group holding clones of every keyframe, with inherited values recomputed."""
        cloned = ContextRefKeyframeGroup()
        for keyframe in self.keyframes:
            cloned.keyframes.append(keyframe.clone())
        cloned._set_first_as_current()
        cloned._prepare_all_keyframe_vals()
        return cloned

    def create_list_of_dicts(self):
        """Return one plain dict per keyframe describing its scheduling and values.

        Keys: ``start_percent``, ``guarantee_steps``, ``inherit_missing``,
        ``strength``, ``mask``, ``tune`` and ``mode``. A Tensor multival is
        passed through as ``mask`` with ``strength`` equal to ``mult``; a float
        multival is folded into ``strength`` and ``mask`` is None.
        """
        c = []
        for kf in self.keyframes:
            d = {}
            # Scheduling.
            d["start_percent"] = kf.start_percent
            d["guarantee_steps"] = kf.guarantee_steps
            d["inherit_missing"] = kf.inherit_missing
            # Values.
            if isinstance(kf.mult_multival, Tensor):
                d["strength"] = kf.mult
                d["mask"] = kf.mult_multival
            else:
                if kf.mult_multival is None:
                    d["strength"] = kf.mult
                else:
                    d["strength"] = kf.mult * kf.mult_multival
                d["mask"] = None
            d["tune"] = kf.tune_replace
            d["mode"] = kf.mode_replace
            c.append(d)
        return c
|
|
|
|
|
class ContextRef(ContextExtra):
    """Context extra carrying a tune, a mode, an optional keyframe group and a strength.

    ``strength_multival`` may be a float or a Tensor. Assigning it splits the
    value into ``strength`` (float part) and ``mask`` (Tensor part).
    """

    def __init__(self, start_percent: float, end_percent: float,
                 strength_multival: Union[float, Tensor], tune: ContextRefTune, mode: ContextRefMode,
                 keyframe: ContextRefKeyframeGroup=None):
        """Store the configuration and derive ``strength``/``mask`` from ``strength_multival``.

        Args:
            start_percent: Start of the active window, forwarded to ContextExtra.
            end_percent: End of the active window, forwarded to ContextExtra.
            strength_multival: Float, Tensor or None (None is treated as 1.0).
            tune: Tuning values stored as ``self.tune``.
            mode: Mode description stored as ``self.mode``.
            keyframe: Keyframe group; an empty group is created when omitted.
        """
        super().__init__(start_percent=start_percent, end_percent=end_percent)
        self.tune = tune
        self.mode = mode
        self.keyframe = keyframe if keyframe else ContextRefKeyframeGroup()
        self.version = CONTEXTREF_VERSION
        # Defaults; overwritten just below by the strength_multival setter.
        self.strength = 1.0
        self.mask = None
        self._strength_multival = strength_multival
        self.strength_multival = strength_multival

    @property
    def strength_multival(self):
        """Return the last value assigned (after None has been replaced by 1.0)."""
        # Must read the private backing field: returning `self.strength_multival`
        # here would re-enter this getter and recurse without end.
        return self._strength_multival

    @strength_multival.setter
    def strength_multival(self, value):
        """Store ``value`` and update ``strength``/``mask`` to match it."""
        if value is None:
            value = 1.0
        if isinstance(value, Tensor):
            # Tensor: used as the mask, with a neutral float strength.
            self.strength = 1.0
            self.mask = value
        else:
            # Plain number: used as the strength, with no mask.
            self.strength = value
            self.mask = None
        self._strength_multival = value

    def should_run(self):
        """Defer to ContextExtra's window check."""
        return super().should_run()
|
|
|
|
|
|
|
|
|
|
|
class NaiveReuseKeyframe:
    """One scheduled keyframe for NaiveReuse: a multiplier plus an optional multival.

    ``orig_mult_multival`` keeps the value passed by the caller, while
    ``mult_multival`` is the working value, which the owning group may replace
    with an inherited one. ``start_t`` holds a placeholder until the owning
    group computes it from ``start_percent``.
    """

    def __init__(self, mult=1.0, mult_multival: Union[float, Tensor]=None, start_percent=0.0, guarantee_steps=1, inherit_missing=True):
        self.mult = mult
        self.orig_mult_multival = mult_multival
        self.mult_multival = mult_multival
        # Scheduling.
        self.start_percent = float(start_percent)
        self.start_t = 999999999.9
        self.guarantee_steps = guarantee_steps
        self.inherit_missing = inherit_missing

    def clone(self):
        """Return an independent copy carrying every field of this keyframe.

        The copy is built from ``orig_mult_multival`` and ``inherit_missing`` so
        that inheritance can still be recomputed correctly for it. The current
        working ``mult_multival`` and ``start_t`` are then copied across so the
        clone also matches this keyframe's present state.
        """
        c = NaiveReuseKeyframe(mult=self.mult, mult_multival=self.orig_mult_multival,
                               start_percent=self.start_percent, guarantee_steps=self.guarantee_steps,
                               inherit_missing=self.inherit_missing)
        c.mult_multival = self.mult_multival
        c.start_t = self.start_t
        return c
|
|
|
|
|
class NaiveReuseKeyframeGroup:
    """Collection of NaiveReuseKeyframe objects kept sorted by ``start_percent``.

    The group resolves inherited ``mult_multival`` values and tracks which
    keyframe is active as sampling advances (see ``prepare_current_keyframe``).
    """

    def __init__(self):
        self.keyframes: list[NaiveReuseKeyframe] = []
        self._current_keyframe: NaiveReuseKeyframe = None
        self._current_used_steps: int = 0
        self._current_index: int = 0
        self._previous_t = -1

    def reset(self):
        """Restore all tracking state to its initial values and re-select the first keyframe."""
        self._current_keyframe = None
        self._current_used_steps = 0
        self._current_index = 0
        # Forget the last seen step too; otherwise a new run whose first t equals the
        # previous run's last t would be mistaken for a repeat and ignored.
        self._previous_t = -1
        self._set_first_as_current()

    def add(self, keyframe: NaiveReuseKeyframe):
        """Add ``keyframe``, re-sort by ``start_percent`` and recompute inherited values."""
        self.keyframes.append(keyframe)
        self.keyframes = get_sorted_list_via_attr(self.keyframes, "start_percent")
        self._set_first_as_current()
        self._prepare_all_keyframe_vals()

    def _set_first_as_current(self):
        """Point ``_current_keyframe`` at the first keyframe, or None when empty."""
        self._current_keyframe = self.keyframes[0] if self.keyframes else None

    def _prepare_all_keyframe_vals(self):
        """Fill each keyframe's ``mult_multival``, inheriting from earlier keyframes when unset."""
        if self.is_empty():
            return
        # Most recent non-None multival seen so far.
        multival = None
        for kf in self.keyframes:
            # A keyframe that opts out of inheritance starts from a clean slate.
            if not kf.inherit_missing:
                multival = None
            # An explicit value wins; otherwise fall back to the inherited one.
            kf.mult_multival = multival if kf.orig_mult_multival is None else kf.orig_mult_multival
            # Remember the resolved value for the keyframes that follow.
            if kf.mult_multival is not None:
                multival = kf.mult_multival

    def has_index(self, index: int) -> bool:
        """Return True when ``index`` is a valid non-negative position in ``keyframes``."""
        return 0 <= index < len(self.keyframes)

    def is_empty(self) -> bool:
        """Return True when the group holds no keyframes."""
        return len(self.keyframes) == 0

    def clone(self):
        """Return a new group holding copies of every keyframe, with inherited values recomputed.

        Keyframes are copied rather than shared so that later mutation of one
        group (``initialize_timesteps``, value inheritance) cannot affect the other.
        """
        cloned = NaiveReuseKeyframeGroup()
        for keyframe in self.keyframes:
            copied = keyframe.clone()
            # Carry these over explicitly so the copy stays faithful regardless of
            # which fields the keyframe's own clone() forwards.
            copied.orig_mult_multival = keyframe.orig_mult_multival
            copied.inherit_missing = keyframe.inherit_missing
            cloned.keyframes.append(copied)
        cloned._set_first_as_current()
        cloned._prepare_all_keyframe_vals()
        return cloned

    def initialize_timesteps(self, model: BaseModel):
        """Compute every keyframe's ``start_t`` from its ``start_percent`` via ``percent_to_sigma``."""
        for keyframe in self.keyframes:
            keyframe.start_t = model.model_sampling.percent_to_sigma(keyframe.start_percent)

    def prepare_current_keyframe(self, t: Tensor):
        """Advance the active keyframe for the step identified by ``t[0]``.

        Repeated calls with the same ``t[0]`` count as a single step. Once the
        active keyframe has been used for its ``guarantee_steps``, later
        keyframes are scanned in order and one is selected when its ``start_t``
        is greater than or equal to the current value.
        """
        if self.is_empty():
            return
        curr_t: float = t[0]
        # Same step as last time: already accounted for.
        if curr_t == self._previous_t:
            return
        # Only consider switching once the guaranteed steps have been used up.
        if self._current_used_steps >= self._current_keyframe.guarantee_steps:
            # The range is empty when there is no following keyframe.
            for i in range(self._current_index + 1, len(self.keyframes)):
                eval_c = self.keyframes[i]
                # start_t comes from percent_to_sigma; a keyframe has started once its
                # start_t is >= the current value (presumably t decreases as sampling
                # progresses, so verify against the sampler if this matters).
                if eval_c.start_t >= curr_t:
                    self._current_index = i
                    self._current_keyframe = eval_c
                    self._current_used_steps = 0
                    # A keyframe with guaranteed steps must be used now; stop searching.
                    # With zero guaranteed steps, keep looking for a later match.
                    if self._current_keyframe.guarantee_steps > 0:
                        break
                # Keyframes are sorted, so the first one not yet reached ends the search.
                else:
                    break
        # Count this step against the (possibly new) active keyframe.
        self._current_used_steps += 1
        self._previous_t = curr_t

    @property
    def mult(self):
        """``mult`` of the active keyframe, or 1.0 when there is none."""
        if self._current_keyframe is not None:
            return self._current_keyframe.mult
        return 1.0

    @property
    def mult_multival(self):
        """``mult_multival`` of the active keyframe, or None when there is none."""
        if self._current_keyframe is not None:
            return self._current_keyframe.mult_multival
        return None
|
|
|
|
|
class NaiveReuse(ContextExtra):
    """Context extra that yields an effective weighted-mean factor.

    The factor is ``weighted_mean`` scaled by the active keyframe's ``mult``
    and, when present, by the combination of the instance's own multival and
    the active keyframe's multival. Tensor multivals are resized to the input
    and cached in ``self.mask``.
    """

    def __init__(self, start_percent: float, end_percent: float, weighted_mean: float, multival_opt: Union[float, Tensor]=None, naivereuse_kf: NaiveReuseKeyframeGroup=None):
        """Store the configuration.

        Args:
            start_percent: Start of the active window, forwarded to ContextExtra.
            end_percent: End of the active window, forwarded to ContextExtra.
            weighted_mean: Base factor.
            multival_opt: Optional float or Tensor multiplier.
            naivereuse_kf: Keyframe group; an empty group is created when omitted.
        """
        super().__init__(start_percent=start_percent, end_percent=end_percent)
        self.weighted_mean = weighted_mean
        self.orig_multival = multival_opt
        self.mask: Tensor = None
        self.keyframe = naivereuse_kf if naivereuse_kf else NaiveReuseKeyframeGroup()
        self._prev_keyframe = None

    def cleanup(self):
        """Drop the cached mask and reset keyframe tracking."""
        super().cleanup()
        # Rebinding to None releases the cached mask while keeping the attribute defined.
        self.mask = None
        self._prev_keyframe = None
        self.keyframe.reset()

    def initialize_timesteps(self, model: BaseModel):
        """Initialise the window thresholds and every keyframe's ``start_t``."""
        super().initialize_timesteps(model)
        self.keyframe.initialize_timesteps(model)

    def prepare_current(self, t: Tensor):
        """Record the current position and advance the active keyframe."""
        super().prepare_current(t)
        self.keyframe.prepare_current_keyframe(t)

    def get_effective_weighted_mean(self, x: Tensor, idxs: list[int]):
        """Return the factor to apply for ``x``.

        Returns a plain number when no Tensor multival is involved. Otherwise
        returns the cached mask, indexed by ``idxs`` and moved to ``x``'s dtype
        and device, scaled by ``weighted_mean`` and the keyframe ``mult``.
        """
        orig_multival = self.orig_multival
        kf_multival = self.keyframe.mult_multival
        base = self.weighted_mean * self.keyframe.mult
        # Nothing to combine: plain scalar factor.
        if orig_multival is None and kf_multival is None:
            return base
        # Track keyframe switches so a cached mask built for another keyframe is rebuilt.
        active_keyframe = self.keyframe._current_keyframe
        keyframe_changed = active_keyframe is not self._prev_keyframe
        self._prev_keyframe = active_keyframe
        # Neither multival is a Tensor: still a scalar result.
        if not isinstance(orig_multival, Tensor) and not isinstance(kf_multival, Tensor):
            return base * get_combined_multival(orig_multival, kf_multival)
        # Rebuild the cached mask when missing, stale, or mismatched with x's batch/spatial dims.
        needs_rebuild = (
            self.mask is None
            or keyframe_changed
            or self.mask.shape[0] != x.shape[0]
            or self.mask.shape[-1] != x.shape[-1]
            or self.mask.shape[-2] != x.shape[-2]
        )
        if needs_rebuild:
            # Release the old mask before building the new one. The attribute stays
            # defined (as None) even if one of the calls below raises.
            self.mask = None
            # NOTE(review): height is taken from x.shape[-1] and width from x.shape[-2];
            # confirm this matches the argument order resize_multival expects.
            real_mult_multival = resize_multival(kf_multival, batch_size=x.shape[0], height=x.shape[-1], width=x.shape[-2])
            resized_orig = resize_multival(orig_multival, batch_size=x.shape[0], height=x.shape[-1], width=x.shape[-2])
            self.mask = get_combined_multival(resized_orig, real_mult_multival)
        return base * self.mask[idxs].to(dtype=x.dtype, device=x.device)

    def _keyframe_multival_is_zero(self) -> bool:
        """Return True when the active keyframe's multival is a number (approximately) equal to 0."""
        multival = self.keyframe.mult_multival
        # abs_tol is required: with the default relative tolerance alone,
        # math.isclose(v, 0.0) is True only when v is exactly 0.
        return (
            multival is not None
            and not isinstance(multival, Tensor)
            and math.isclose(multival, 0.0, abs_tol=1e-9)
        )

    def should_run(self):
        """Return True when inside the window and the resulting factor can be non-zero."""
        to_return = super().should_run()
        # A scalar keyframe multival of zero makes the whole factor zero.
        if self._keyframe_multival_is_zero():
            return False
        return to_return and self.weighted_mean > 0.0 and self.keyframe.mult > 0.0
|
|
|
|
|
|
|
class ContextExtrasGroup:
    """Holder for at most one ContextRef and one NaiveReuse, with fan-out helpers."""

    def __init__(self):
        self.context_ref: ContextRef = None
        self.naive_reuse: NaiveReuse = None

    def get_extras_list(self) -> list[ContextExtra]:
        """Return the extras that are set, ContextRef first, then NaiveReuse."""
        candidates = (self.context_ref, self.naive_reuse)
        return [extra for extra in candidates if extra is not None]

    def initialize_timesteps(self, model: BaseModel):
        """Call ``initialize_timesteps(model)`` on every extra that is set."""
        for extra in self.get_extras_list():
            extra.initialize_timesteps(model)

    def prepare_current(self, t: Tensor):
        """Call ``prepare_current(t)`` on every extra that is set."""
        for extra in self.get_extras_list():
            extra.prepare_current(t)

    def should_run_context_ref(self):
        """Return the ContextRef's ``should_run()``, or False when none is set."""
        if not self.context_ref:
            return False
        return self.context_ref.should_run()

    def should_run_naive_reuse(self):
        """Return the NaiveReuse's ``should_run()``, or False when none is set."""
        if not self.naive_reuse:
            return False
        return self.naive_reuse.should_run()

    def add(self, extra: ContextExtra):
        """Store ``extra`` in the slot matching its type, replacing any previous value.

        Subclasses of ContextRef / NaiveReuse are accepted as well.

        Raises:
            Exception: If ``extra`` is neither a ContextRef nor a NaiveReuse.
        """
        if isinstance(extra, ContextRef):
            self.context_ref = extra
        elif isinstance(extra, NaiveReuse):
            self.naive_reuse = extra
        else:
            raise Exception(f"Unrecognized ContextExtras type: {type(extra)}")

    def cleanup(self):
        """Call ``cleanup()`` on every extra that is set."""
        for extra in self.get_extras_list():
            extra.cleanup()

    def clone(self):
        """Return a new group referencing the same extras (the extras are not copied)."""
        cloned = ContextExtrasGroup()
        cloned.context_ref = self.context_ref
        cloned.naive_reuse = self.naive_reuse
        return cloned
|
|