jaxmetaverse's picture
Upload folder using huggingface_hub
82ea528 verified
from typing import Union
import math
import torch
from torch import Tensor
from comfy.model_base import BaseModel
from .utils_motion import (prepare_mask_batch, extend_to_batch_size, get_combined_multival, resize_multival,
get_sorted_list_via_attr)
CONTEXTREF_VERSION = 1
class ContextExtra:
def __init__(self, start_percent: float, end_percent: float):
# scheduling
self.start_percent = float(start_percent)
self.start_t = 999999999.9
self.end_percent = float(end_percent)
self.end_t = 0.0
self.curr_t = 999999999.9
def initialize_timesteps(self, model: BaseModel):
self.start_t = model.model_sampling.percent_to_sigma(self.start_percent)
self.end_t = model.model_sampling.percent_to_sigma(self.end_percent)
def prepare_current(self, t: Tensor):
self.curr_t = t[0]
def should_run(self):
if self.curr_t > self.start_t or self.curr_t < self.end_t:
return False
return True
def cleanup(self):
pass
################################
# ContextRef
class ContextRefTune:
def __init__(self,
attn_style_fidelity=0.0, attn_ref_weight=0.0, attn_strength=0.0,
adain_style_fidelity=0.0, adain_ref_weight=0.0, adain_strength=0.0):
# attn1
self.attn_style_fidelity = float(attn_style_fidelity)
self.attn_ref_weight = float(attn_ref_weight)
self.attn_strength = float(attn_strength)
# adain
self.adain_style_fidelity = float(adain_style_fidelity)
self.adain_ref_weight = float(adain_ref_weight)
self.adain_strength = float(adain_strength)
def create_dict(self):
return {
"attn_style_fidelity": self.attn_style_fidelity,
"attn_ref_weight": self.attn_ref_weight,
"attn_strength": self.attn_strength,
"adain_style_fidelity": self.adain_style_fidelity,
"adain_ref_weight": self.adain_ref_weight,
"adain_strength": self.adain_strength,
}
class ContextRefMode:
FIRST = "first"
SLIDING = "sliding"
INDEXES = "indexes"
_LIST = [FIRST, SLIDING, INDEXES]
def __init__(self, mode: str, sliding_width=2, indexes: set[int]=set([0])):
self.mode = mode
self.sliding_width = sliding_width
self.indexes = indexes
self.single_trigger = True
@classmethod
def init_first(cls):
return ContextRefMode(cls.FIRST)
@classmethod
def init_sliding(cls, sliding_width: int):
return ContextRefMode(cls.SLIDING, sliding_width=sliding_width)
@classmethod
def init_indexes(cls, indexes: set[int]):
return ContextRefMode(cls.INDEXES, indexes=indexes)
class ContextRefKeyframe:
def __init__(self, mult=1.0, mult_multival: Union[float, Tensor]=None, tune_replace: ContextRefTune=None, mode_replace: ContextRefMode=None,
start_percent=0.0, guarantee_steps=1, inherit_missing=True):
self.mult = mult
self.orig_mult_multival = mult_multival
self.orig_tune_replace = tune_replace
self.orig_mode_replace = mode_replace
self.mult_multival = self.orig_mult_multival
self.tune_replace = self.orig_tune_replace
self.mode_replace = self.orig_mode_replace
# scheduling
self.start_percent = float(start_percent)
self.guarantee_steps = guarantee_steps
self.inherit_missing = inherit_missing
def clone(self):
c = ContextRefKeyframe(mult=self.mult, mult_multival=self.orig_mult_multival, tune_replace=self.orig_tune_replace, mode_replace=self.orig_mode_replace,
start_percent=self.start_percent, guarantee_steps=self.guarantee_steps, inherit_missing=self.inherit_missing)
return c
class ContextRefKeyframeGroup:
def __init__(self):
self.keyframes: list[ContextRefKeyframe] = []
self._current_keyframe: NaiveReuseKeyframe = None
self._current_used_steps: int = 0
self._current_index: int = 0
self._previous_t = -1
def reset(self):
self._current_keyframe = None
self._current_used_steps = 0
self._current_index = 0
self._set_first_as_current()
def add(self, keyframe: ContextRefKeyframe):
# add to end of list, then sort
self.keyframes.append(keyframe)
self.keyframes = get_sorted_list_via_attr(self.keyframes, "start_percent")
self._set_first_as_current()
self._prepare_all_keyframe_vals()
def _set_first_as_current(self):
if len(self.keyframes) > 0:
self._current_keyframe = self.keyframes[0]
else:
self._current_keyframe = None
def _prepare_all_keyframe_vals(self):
if self.is_empty():
return
multival = None
tune = None
mode = None
for kf in self.keyframes:
# if shouldn't inherit, clear cache
if not kf.inherit_missing:
multival = None
tune = None
mode = None
# assign cached values, if origs were None
# Mult #################
if kf.orig_mult_multival is None:
kf.mult_multival = multival
else:
kf.mult_multival = kf.orig_mult_multival
# Tune #################
if kf.orig_tune_replace is None:
kf.tune_replace = tune
else:
kf.tune_replace = kf.orig_tune_replace
# Mode #################
if kf.orig_mode_replace is None:
kf.mode_replace = mode
else:
kf.mode_replace = kf.orig_mode_replace
# save new caches, in case next keyframe inherits missing
if kf.mult_multival is not None:
multival = kf.mult_multival
if kf.tune_replace is not None:
tune = kf.tune_replace
if kf.mode_replace is not None:
mode = kf.mode_replace
def has_index(self, index: int) -> int:
return index >=0 and index < len(self.keyframes)
def is_empty(self) -> bool:
return len(self.keyframes) == 0
def clone(self):
cloned = ContextRefKeyframeGroup()
for keyframe in self.keyframes:
cloned.keyframes.append(keyframe.clone())
cloned._set_first_as_current()
cloned._prepare_all_keyframe_vals()
return cloned
def create_list_of_dicts(self):
# for each keyframe, create a dict representing values relevant to TimestepKeyframe creation in ACN
c = []
for kf in self.keyframes:
d = {}
# scheduling
d["start_percent"] = kf.start_percent
d["guarantee_steps"] = kf.guarantee_steps
d["inherit_missing"] = kf.inherit_missing
# values
if type(kf.mult_multival) == Tensor:
d["strength"] = kf.mult
d["mask"] = kf.mult_multival
else:
if kf.mult_multival is None:
d["strength"] = kf.mult
else:
d["strength"] = kf.mult * kf.mult_multival
d["mask"] = None
d["tune"] = kf.tune_replace
d["mode"] = kf.mode_replace
# add to list
c.append(d)
return c
class ContextRef(ContextExtra):
def __init__(self, start_percent: float, end_percent: float,
strength_multival: Union[float, Tensor], tune: ContextRefTune, mode: ContextRefMode,
keyframe: ContextRefKeyframeGroup=None):
super().__init__(start_percent=start_percent, end_percent=end_percent)
self.tune = tune
self.mode = mode
self.keyframe = keyframe if keyframe else ContextRefKeyframeGroup()
self.version = CONTEXTREF_VERSION
# stuff for ACN usage
self.strength = 1.0
self.mask = None
self._strength_multival = strength_multival
self.strength_multival = strength_multival
@property
def strength_multival(self):
return self.strength_multival
@strength_multival.setter
def strength_multival(self, value):
if value is None:
value = 1.0
if type(value) == Tensor:
self.strength = 1.0
self.mask = value
else:
self.strength = value
self.mask = None
self._strength_multival = value
def should_run(self):
return super().should_run()
#--------------------------------
################################
# NaiveReuse
class NaiveReuseKeyframe:
def __init__(self, mult=1.0, mult_multival: Union[float, Tensor]=None, start_percent=0.0, guarantee_steps=1, inherit_missing=True):
self.mult = mult
self.orig_mult_multival = mult_multival
self.mult_multival = mult_multival
# scheduling
self.start_percent = float(start_percent)
self.start_t = 999999999.9
self.guarantee_steps = guarantee_steps
self.inherit_missing = inherit_missing
def clone(self):
c = NaiveReuseKeyframe(mult=self.mult, mult_multival=self.mult_multival,
start_percent=self.start_percent, guarantee_steps=self.guarantee_steps)
c.start_t = self.start_t
return c
class NaiveReuseKeyframeGroup:
def __init__(self):
self.keyframes: list[NaiveReuseKeyframe] = []
self._current_keyframe: NaiveReuseKeyframe = None
self._current_used_steps: int = 0
self._current_index: int = 0
self._previous_t = -1
def reset(self):
self._current_keyframe = None
self._current_used_steps = 0
self._current_index = 0
self._set_first_as_current()
def add(self, keyframe: NaiveReuseKeyframe):
# add to end of list, then sort
self.keyframes.append(keyframe)
self.keyframes = get_sorted_list_via_attr(self.keyframes, "start_percent")
self._set_first_as_current()
self._prepare_all_keyframe_vals()
def _set_first_as_current(self):
if len(self.keyframes) > 0:
self._current_keyframe = self.keyframes[0]
else:
self._current_keyframe = None
def _prepare_all_keyframe_vals(self):
if self.is_empty():
return
multival = None
for kf in self.keyframes:
# if shouldn't inherit, clear cache
if not kf.inherit_missing:
multival = None
# assign cached values, if origs were None
# Mult #################
if kf.orig_mult_multival is None:
kf.mult_multival = multival
else:
kf.mult_multival = kf.orig_mult_multival
# save new caches, in case next keyframe inherits missing
if kf.mult_multival is not None:
multival = kf.mult_multival
def has_index(self, index: int) -> int:
return index >=0 and index < len(self.keyframes)
def is_empty(self) -> bool:
return len(self.keyframes) == 0
def clone(self):
cloned = NaiveReuseKeyframeGroup()
for keyframe in self.keyframes:
cloned.keyframes.append(keyframe)
cloned._set_first_as_current()
cloned._prepare_all_keyframe_vals()
return cloned
def initialize_timesteps(self, model: BaseModel):
for keyframe in self.keyframes:
keyframe.start_t = model.model_sampling.percent_to_sigma(keyframe.start_percent)
def prepare_current_keyframe(self, t: Tensor):
if self.is_empty():
return
curr_t: float = t[0]
# if curr_t same as before, do nothing as step already accounted for
if curr_t == self._previous_t:
return
prev_index = self._current_index
# if met guaranteed steps, look for next keyframe in case need to switch
if self._current_used_steps >= self._current_keyframe.guarantee_steps:
# if has next index, loop through and see if need t oswitch
if self.has_index(self._current_index+1):
for i in range(self._current_index+1, len(self.keyframes)):
eval_c = self.keyframes[i]
# check if start_t is greater or equal to curr_t
# NOTE: t is in terms of sigmas, not percent, so bigger number = earlier step in sampling
if eval_c.start_t >= curr_t:
self._current_index = i
self._current_keyframe = eval_c
self._current_used_steps = 0
# if guarantee_steps greater than zero, stop searching for other keyframes
if self._current_keyframe.guarantee_steps > 0:
break
# if eval_c is outside the percent range, stop looking further
else: break
# update steps current context is used
self._current_used_steps += 1
# update previous_t
self._previous_t = curr_t
# properties shadow those of NaiveReuseKeyframe
@property
def mult(self):
if self._current_keyframe != None:
return self._current_keyframe.mult
return 1.0
@property
def mult_multival(self):
if self._current_keyframe != None:
return self._current_keyframe.mult_multival
return None
class NaiveReuse(ContextExtra):
def __init__(self, start_percent: float, end_percent: float, weighted_mean: float, multival_opt: Union[float, Tensor]=None, naivereuse_kf: NaiveReuseKeyframeGroup=None):
super().__init__(start_percent=start_percent, end_percent=end_percent)
self.weighted_mean = weighted_mean
self.orig_multival = multival_opt
self.mask: Tensor = None
self.keyframe = naivereuse_kf if naivereuse_kf else NaiveReuseKeyframeGroup()
self._prev_keyframe = None
def cleanup(self):
super().cleanup()
del self.mask
self.mask = None
self._prev_keyframe = None
self.keyframe.reset()
def initialize_timesteps(self, model: BaseModel):
super().initialize_timesteps(model)
self.keyframe.initialize_timesteps(model)
def prepare_current(self, t: Tensor):
super().prepare_current(t)
self.keyframe.prepare_current_keyframe(t)
def get_effective_weighted_mean(self, x: Tensor, idxs: list[int]):
if self.orig_multival is None and self.keyframe.mult_multival is None:
return self.weighted_mean * self.keyframe.mult
# check if keyframe changed
keyframe_changed = False
if self.keyframe._current_keyframe != self._prev_keyframe:
keyframe_changed = True
self._prev_keyframe = self.keyframe._current_keyframe
if type(self.orig_multival) != Tensor and type(self.keyframe.mult_multival) != Tensor:
return self.weighted_mean * self.keyframe.mult * get_combined_multival(self.orig_multival, self.keyframe.mult_multival)
if self.mask is None or keyframe_changed or self.mask.shape[0] != x.shape[0] or self.mask.shape[-1] != x.shape[-1] or self.mask.shape[-2] != x.shape[-2]:
del self.mask
real_mult_multival = resize_multival(self.keyframe.mult_multival, batch_size=x.shape[0], height=x.shape[-1], width=x.shape[-2])
self.mask = resize_multival(self.orig_multival, batch_size=x.shape[0], height=x.shape[-1], width=x.shape[-2])
self.mask = get_combined_multival(self.mask, real_mult_multival)
return self.weighted_mean * self.keyframe.mult * self.mask[idxs].to(dtype=x.dtype, device=x.device)
def should_run(self):
to_return = super().should_run()
# if keyframe has 0.0 val, should not run
if self.keyframe.mult_multival is not None and type(self.keyframe.mult_multival) != Tensor and math.isclose(self.keyframe.mult_multival, 0.0):
return False
# if weighted_mean is 0.0, then reuse will take no effect anyway
return to_return and self.weighted_mean > 0.0 and self.keyframe.mult > 0.0
#--------------------------------
class ContextExtrasGroup:
def __init__(self):
self.context_ref: ContextRef = None
self.naive_reuse: NaiveReuse = None
def get_extras_list(self) -> list[ContextExtra]:
extras_list = []
if self.context_ref is not None:
extras_list.append(self.context_ref)
if self.naive_reuse is not None:
extras_list.append(self.naive_reuse)
return extras_list
def initialize_timesteps(self, model: BaseModel):
for extra in self.get_extras_list():
extra.initialize_timesteps(model)
def prepare_current(self, t: Tensor):
for extra in self.get_extras_list():
extra.prepare_current(t)
def should_run_context_ref(self):
if not self.context_ref:
return False
return self.context_ref.should_run()
def should_run_naive_reuse(self):
if not self.naive_reuse:
return False
return self.naive_reuse.should_run()
def add(self, extra: ContextExtra):
if type(extra) == ContextRef:
self.context_ref = extra
elif type(extra) == NaiveReuse:
self.naive_reuse = extra
else:
raise Exception(f"Unrecognized ContextExtras type: {type(extra)}")
def cleanup(self):
for extra in self.get_extras_list():
extra.cleanup()
def clone(self):
cloned = ContextExtrasGroup()
cloned.context_ref = self.context_ref
cloned.naive_reuse = self.naive_reuse
return cloned