# Source: uploaded by jaxmetaverse via huggingface_hub
# ("Upload folder using huggingface_hub"), commit 82ea528 (verified).
from torch import Tensor
from comfy.model_base import BaseModel
from .utils_motion import get_sorted_list_via_attr
class LoraHookMode:
    '''
    String identifiers for the available LoRA hook application modes.

    Only the identifiers live here; how each mode is applied is implemented
    elsewhere. Low-VRAM variants ("min_vram_lowvram", "max_speed_lowvram")
    are currently disabled.
    '''
    MAX_SPEED = "max_speed"
    MIN_VRAM = "min_vram"
# Acts simply as a way to track unique LoraHooks
class HookRef:
pass
class LoraHook:
    '''
    A named LoRA plus the keyframe schedule that controls its strength.

    Equality and hashing are based solely on `hook_ref`. Copies made with
    get_copy() share the original's HookRef and therefore compare equal to it.
    Independently constructed hooks each get a fresh HookRef and never do,
    even with the same lora_name.
    '''
    def __init__(self, lora_name: str):
        self.lora_name = lora_name
        self.lora_keyframe = LoraHookKeyframeGroup()
        self.hook_ref = HookRef()

    def initialize_timesteps(self, model: BaseModel):
        '''Delegate timestep initialization to this hook's keyframe group.'''
        self.lora_keyframe.initialize_timesteps(model)

    def reset(self):
        '''Delegate cursor reset to this hook's keyframe group.'''
        self.lora_keyframe.reset()

    def get_copy(self):
        '''
        Return a new LoraHook with the same name, the same HookRef, and the
        *same* (not cloned) keyframe group object.
        '''
        duplicate = LoraHook(lora_name=self.lora_name)
        # Share the schedule object rather than cloning it.
        duplicate.lora_keyframe = self.lora_keyframe
        # Sharing the ref makes the copy equal to (and hash like) the original.
        duplicate.hook_ref = self.hook_ref
        return duplicate

    @property
    def strength(self):
        '''Current strength reported by the keyframe group.'''
        return self.lora_keyframe.strength

    def __eq__(self, other: 'LoraHook'):
        if self.__class__ != other.__class__:
            return False
        return self.hook_ref == other.hook_ref

    def __hash__(self):
        return hash(self.hook_ref)
class LoraHookGroup:
    '''
    Stores LoRA hooks to apply for conditioning.

    Hooks are kept in insertion order and de-duplicated by LoraHook equality
    (shared HookRef). Adding a copy of a hook that is already present does
    nothing.
    '''
    def __init__(self):
        self.hooks: list['LoraHook'] = []

    def names(self):
        '''Return the hooks' lora_name values joined by commas, in insertion order.'''
        return ",".join(hook.lora_name for hook in self.hooks)

    def add(self, hook: 'LoraHook'):
        '''Append hook unless an equal hook is already stored.'''
        if hook not in self.hooks:
            self.hooks.append(hook)

    def is_empty(self):
        return not self.hooks

    def contains(self, lora_hook: 'LoraHook'):
        return lora_hook in self.hooks

    def clone(self):
        '''Return a new group holding get_copy() of each hook (HookRefs are preserved).'''
        cloned = LoraHookGroup()
        for hook in self.hooks:
            cloned.add(hook.get_copy())
        return cloned

    def clone_and_combine(self, other: 'LoraHookGroup'):
        '''
        Return a clone of this group extended with copies of other's hooks
        that are not already present. Neither input group is modified.
        '''
        cloned = self.clone()
        for hook in other.hooks:
            cloned.add(hook.get_copy())
        return cloned

    def set_keyframes_on_hooks(self, hook_kf: 'LoraHookKeyframeGroup'):
        '''
        Assign a clone of hook_kf to every hook in this group.

        One clone is made and that single object is shared by all the hooks.
        '''
        hook_kf = hook_kf.clone()
        for hook in self.hooks:
            hook.lora_keyframe = hook_kf

    @staticmethod
    def combine_all_lora_hooks(lora_hooks_list: list['LoraHookGroup'], require_count=1) -> 'LoraHookGroup | None':
        '''
        Merge several LoraHookGroups into one, skipping None entries.

        Args:
            lora_hooks_list: groups to merge; None entries are ignored.
            require_count: minimum number of non-None groups required.

        Returns:
            - the single group itself (not a clone) if exactly one non-None
              group is given;
            - a new group with copies of all hooks, in order and de-duplicated,
              if several are given;
            - None if no groups are given (only possible when
              require_count <= 0).

        Raises:
            ValueError: if fewer than require_count non-None groups are given.
        '''
        actual = [group for group in lora_hooks_list if group is not None]
        if len(actual) < require_count:
            raise ValueError(f"Need at least {require_count} LoRA Hooks to combine, but only had {len(actual)}.")
        if not actual:
            return None
        # if only 1 hook, just return itself without any cloning
        if len(actual) == 1:
            return actual[0]
        # Clone the first group once, then append copies from the rest.
        # Re-cloning the accumulated result on every iteration would copy
        # each hook many times.
        combined = actual[0].clone()
        for group in actual[1:]:
            for hook in group.hooks:
                combined.add(hook.get_copy())
        return combined
class LoraHookKeyframe:
    '''
    One entry in a LoRA strength schedule.

    Attributes:
        strength: strength reported while this keyframe is active.
        start_percent: sampling progress at which it may start (stored as float).
        guarantee_steps: minimum number of steps this keyframe stays active
            before a later keyframe may replace it.
        start_t: timestep equivalent of start_percent. It holds a large
            sentinel value until it is assigned by the owning group's
            initialize_timesteps().
    '''
    def __init__(self, strength: float, start_percent=0.0, guarantee_steps=1):
        self.strength = strength
        self.start_percent = float(start_percent)
        self.start_t = 999999999.9  # placeholder until timesteps are initialized
        self.guarantee_steps = guarantee_steps

    def clone(self):
        '''Return an independent LoraHookKeyframe with identical fields, including start_t.'''
        twin = LoraHookKeyframe(self.strength, self.start_percent, self.guarantee_steps)
        twin.start_t = self.start_t
        return twin
class LoraHookKeyframeGroup:
    '''
    Strength schedule for a LoraHook.

    Holds keyframes kept sorted by start_percent when added via add(), plus a
    cursor (current keyframe, its index, steps used, and last timestep seen)
    that prepare_current_keyframe() advances during sampling.
    '''
    def __init__(self):
        self.keyframes: list['LoraHookKeyframe'] = []
        self._current_keyframe: 'LoraHookKeyframe' = None
        self._current_used_steps: int = 0
        self._current_index: int = 0
        self._curr_t: float = -1

    def reset(self):
        '''Rewind the cursor to the first keyframe and forget the last timestep seen.'''
        self._current_keyframe = None
        self._current_used_steps = 0
        self._current_index = 0
        self._curr_t = -1
        self._set_first_as_current()

    def add(self, keyframe: 'LoraHookKeyframe'):
        '''Insert keyframe, re-sort by start_percent, and point the cursor at the first keyframe.'''
        self.keyframes.append(keyframe)
        self.keyframes = get_sorted_list_via_attr(self.keyframes, "start_percent")
        # _set_first_as_current() selects keyframes[0]; keep the index in sync with it.
        self._current_index = 0
        self._set_first_as_current()

    def _set_first_as_current(self):
        '''Make keyframes[0] the current keyframe, or None if there are no keyframes.'''
        self._current_keyframe = self.keyframes[0] if self.keyframes else None

    def has_index(self, index: int) -> bool:
        '''Return True if index is a valid (non-negative) position in keyframes.'''
        return 0 <= index < len(self.keyframes)

    def is_empty(self) -> bool:
        return not self.keyframes

    def clone(self):
        '''
        Return a new group with a fresh cursor over the same keyframes.

        The list is copied, but the LoraHookKeyframe objects are shared with
        this group.
        '''
        cloned = LoraHookKeyframeGroup()
        cloned.keyframes.extend(self.keyframes)
        cloned._set_first_as_current()
        return cloned

    def initialize_timesteps(self, model: 'BaseModel'):
        '''Set each keyframe's start_t to model.model_sampling.percent_to_sigma(start_percent).'''
        for keyframe in self.keyframes:
            keyframe.start_t = model.model_sampling.percent_to_sigma(keyframe.start_percent)

    def prepare_current_keyframe(self, curr_t: float) -> bool:
        '''
        Update the active keyframe for the sampling step at timestep curr_t.

        A repeated call with the same curr_t as the previous call does nothing.
        Once the current keyframe has been counted for at least its
        guarantee_steps calls, later keyframes with start_t >= curr_t become
        active in order. The scan stops after activating one with
        guarantee_steps > 0, or on reaching one with start_t < curr_t.

        Returns:
            True if the active keyframe index changed, else False (always
            False when there are no keyframes).
        '''
        if self.is_empty():
            return False
        if curr_t == self._curr_t:
            return False
        prev_index = self._current_index
        # if met guaranteed steps, look for next keyframe in case need to switch
        if self._current_used_steps >= self._current_keyframe.guarantee_steps:
            # range() is empty when there is no next keyframe
            for i in range(self._current_index + 1, len(self.keyframes)):
                eval_c = self.keyframes[i]
                # NOTE: t is in terms of sigmas, not percent, so bigger number = earlier step in sampling
                if eval_c.start_t >= curr_t:
                    self._current_index = i
                    self._current_keyframe = eval_c
                    self._current_used_steps = 0
                    # if guarantee_steps greater than zero, stop searching for other keyframes
                    if self._current_keyframe.guarantee_steps > 0:
                        break
                else:
                    # keyframe has not started yet at curr_t; stop looking further
                    break
        # count this step against the current keyframe
        self._current_used_steps += 1
        # remember the timestep so a repeated call for it is a no-op
        self._curr_t = curr_t
        return prev_index != self._current_index

    # properties shadow those of LoraHookKeyframe
    @property
    def strength(self):
        '''Strength of the current keyframe, or 1.0 when none is active.'''
        if self._current_keyframe is not None:
            return self._current_keyframe.strength
        return 1.0
class COND_CONST:
    '''Shared option keys and set_cond_area choices for conditioning entries.'''
    # Keys written into a conditioning entry's options dict.
    KEY_LORA_HOOK = "lora_hook"
    KEY_DEFAULT_COND = "default_cond"

    # Accepted values for set_cond_area.
    COND_AREA_DEFAULT = "default"
    COND_AREA_MASK_BOUNDS = "mask bounds"
    _LIST_COND_AREA = [COND_AREA_DEFAULT, COND_AREA_MASK_BOUNDS]
class TimestepsCond:
    '''Start/end sampling percents to copy onto conditioning entries.'''
    def __init__(self, start_percent: float, end_percent: float):
        self.start_percent, self.end_percent = start_percent, end_percent
def conditioning_set_values(conditioning, values=None):
    '''
    Return a copy of `conditioning` with `values` merged into each entry's options dict.

    Each entry `t` is rebuilt as `[t[0], copy of t[1] updated with values]`:
    - t[0] is shared, not copied;
    - keys in `values` overwrite existing keys;
    - any items in `t` beyond index 1 are not carried over.
    The input list and its dicts are not mutated.

    Args:
        conditioning: iterable of entries indexable as t[0], t[1] with t[1] a dict.
        values: mapping of option keys to set. None (the default) means no changes.

    Returns:
        A new list of `[t[0], dict]` pairs.
    '''
    # None sentinel instead of a shared mutable `{}` default.
    if values is None:
        values = {}
    result = []
    for t in conditioning:
        options = t[1].copy()
        options.update(values)
        result.append([t[0], options])
    return result
def set_lora_hook_for_conditioning(conditioning, lora_hook: LoraHookGroup):
    '''
    Attach `lora_hook` under COND_CONST.KEY_LORA_HOOK to every conditioning entry.

    Returns a new conditioning list, or the original object unchanged when
    lora_hook is None.
    '''
    if lora_hook is not None:
        return conditioning_set_values(conditioning, {COND_CONST.KEY_LORA_HOOK: lora_hook})
    return conditioning
def set_timesteps_for_conditioning(conditioning, timesteps_cond: TimestepsCond):
    '''
    Copy timesteps_cond's start/end percents into every conditioning entry.

    Sets the "start_percent" and "end_percent" keys. Returns a new
    conditioning list, or the original object unchanged when timesteps_cond
    is None.
    '''
    if timesteps_cond is not None:
        timestep_values = {"start_percent": timesteps_cond.start_percent,
                           "end_percent": timesteps_cond.end_percent}
        return conditioning_set_values(conditioning, timestep_values)
    return conditioning
def set_mask_for_conditioning(conditioning, mask: Tensor, set_cond_area: str, strength: float):
    '''
    Attach a mask to every conditioning entry.

    Sets the "mask", "set_area_to_bounds" and "mask_strength" keys.
    set_area_to_bounds is True for any set_cond_area other than
    COND_CONST.COND_AREA_DEFAULT. Returns the original object unchanged when
    mask is None.
    '''
    if mask is None:
        return conditioning
    set_area_to_bounds = set_cond_area != COND_CONST.COND_AREA_DEFAULT
    # Masks with fewer than 3 dims get a leading dim
    # (presumably a batch dim — confirm against callers).
    if len(mask.shape) < 3:
        mask = mask.unsqueeze(0)
    mask_values = {
        "mask": mask,
        "set_area_to_bounds": set_area_to_bounds,
        "mask_strength": strength,
    }
    return conditioning_set_values(conditioning, mask_values)
def combine_conditioning(conds: list):
    '''Concatenate several conditioning lists into one new flat list (entries are not copied).'''
    return [entry for cond in conds for entry in cond]
def set_mask_conds(conds: list, strength: float, set_cond_area: str,
                   opt_mask: Tensor=None, opt_lora_hook: LoraHookGroup=None, opt_timesteps: TimestepsCond=None):
    '''
    Apply the optional LoRA hook, mask, and timesteps to each conditioning in conds.

    Each step is skipped when its option is None; the order is hook, then
    mask, then timesteps. Returns a new list with one processed conditioning
    per input, in the same order.
    '''
    def _decorate(cond):
        cond = set_lora_hook_for_conditioning(cond, opt_lora_hook)
        cond = set_mask_for_conditioning(conditioning=cond, mask=opt_mask, strength=strength, set_cond_area=set_cond_area)
        return set_timesteps_for_conditioning(conditioning=cond, timesteps_cond=opt_timesteps)

    return [_decorate(cond) for cond in conds]
def set_mask_and_combine_conds(conds: list, new_conds: list, strength: float=1.0, set_cond_area: str="default",
                               opt_mask: Tensor=None, opt_lora_hook: LoraHookGroup=None, opt_timesteps: TimestepsCond=None):
    '''
    Decorate each new conditioning and append it to its existing counterpart.

    Each entry of new_conds gets the optional LoRA hook, mask, and timesteps
    (skipped when None, applied in that order). It is then concatenated after
    the matching entry of conds. Pairs are formed with zip(), so extra
    entries in the longer list are ignored.
    '''
    def _decorate(new_cond):
        new_cond = set_lora_hook_for_conditioning(new_cond, opt_lora_hook)
        new_cond = set_mask_for_conditioning(conditioning=new_cond, mask=opt_mask, set_cond_area=set_cond_area, strength=strength)
        return set_timesteps_for_conditioning(conditioning=new_cond, timesteps_cond=opt_timesteps)

    return [combine_conditioning([existing, _decorate(incoming)])
            for existing, incoming in zip(conds, new_conds)]
def set_unmasked_and_combine_conds(conds: list, new_conds: list,
                                   opt_lora_hook: LoraHookGroup, opt_timesteps: TimestepsCond=None):
    '''
    Tag each new conditioning as a default (unmasked) cond and append it to
    its existing counterpart.

    Each entry of new_conds:
    - gets the optional LoRA hook (skipped when None);
    - is marked with COND_CONST.KEY_DEFAULT_COND=True so sampling code can
      identify it;
    - gets the optional timesteps (skipped when None);
    - is concatenated after the matching entry of conds.
    Pairs are formed with zip(), so extra entries in the longer list are
    ignored.
    '''
    def _decorate(new_cond):
        new_cond = set_lora_hook_for_conditioning(new_cond, opt_lora_hook)
        new_cond = conditioning_set_values(new_cond, {COND_CONST.KEY_DEFAULT_COND: True})
        return set_timesteps_for_conditioning(conditioning=new_cond, timesteps_cond=opt_timesteps)

    return [combine_conditioning([existing, _decorate(incoming)])
            for existing, incoming in zip(conds, new_conds)]