|
from torch import Tensor |
|
|
|
from comfy.model_base import BaseModel |
|
|
|
from .utils_motion import get_sorted_list_via_attr |
|
|
|
|
|
class LoraHookMode:
    """String constants selecting how hooked LoRA weights are applied.

    NOTE(review): the names suggest a VRAM-usage vs. sampling-speed
    tradeoff in the code that consumes these values — confirm against
    the model-patching code that reads them.
    """

    MIN_VRAM = "min_vram"

    MAX_SPEED = "max_speed"
|
|
|
|
|
|
|
|
|
|
|
class HookRef:
    """Identity token shared between copies of a LoraHook.

    Carries no state; equality and hashing fall back to object identity,
    so two LoraHooks compare equal iff they share the same HookRef
    (see LoraHook.get_copy/__eq__/__hash__).
    """

    pass
|
|
|
|
|
class LoraHook:
    """A named LoRA hook whose strength is scheduled by keyframes.

    Identity is determined by ``hook_ref`` rather than by value: copies
    made via :meth:`get_copy` share the same HookRef and therefore
    compare equal to (and hash the same as) the original.
    """

    def __init__(self, lora_name: str):
        self.lora_name = lora_name
        # Strength schedule over sampling; shared (not copied) by get_copy.
        self.lora_keyframe = LoraHookKeyframeGroup()
        # Shared identity token; see get_copy/__eq__/__hash__.
        self.hook_ref = HookRef()

    def initialize_timesteps(self, model: 'BaseModel'):
        """Resolve keyframe start percents into model sampling sigmas."""
        self.lora_keyframe.initialize_timesteps(model)

    def reset(self):
        """Reset keyframe scheduling state before a new sampling run."""
        self.lora_keyframe.reset()

    def get_copy(self):
        '''
        Copies LoraHook, but maintains same HookRef
        '''
        c = LoraHook(lora_name=self.lora_name)
        c.lora_keyframe = self.lora_keyframe
        c.hook_ref = self.hook_ref
        return c

    @property
    def strength(self):
        """Current strength reported by the keyframe group."""
        return self.lora_keyframe.strength

    def __eq__(self, other: 'LoraHook'):
        # Return NotImplemented for unsupported operands per the Python
        # data model (falls back to identity, so `hook == other_type` is
        # still False for existing callers) instead of claiming False.
        if self.__class__ != other.__class__:
            return NotImplemented
        return self.hook_ref == other.hook_ref

    def __hash__(self):
        return hash(self.hook_ref)
|
|
|
|
|
class LoraHookGroup:
    '''
    Stores LoRA hooks to apply for conditioning
    '''
    def __init__(self):
        self.hooks: list[LoraHook] = []

    def names(self):
        """Return a comma-joined string of all contained hooks' lora names."""
        return ",".join(hook.lora_name for hook in self.hooks)

    def add(self, hook: LoraHook):
        # Skip duplicates; hook equality is HookRef-based.
        if hook not in self.hooks:
            self.hooks.append(hook)

    def is_empty(self):
        """True when no hooks are stored."""
        return len(self.hooks) == 0

    def contains(self, lora_hook: LoraHook):
        """True when an equal (same HookRef) hook is present."""
        return lora_hook in self.hooks

    def clone(self):
        """Return a new group holding copies of each hook (HookRefs shared)."""
        duplicate = LoraHookGroup()
        for existing in self.hooks:
            duplicate.add(existing.get_copy())
        return duplicate

    def clone_and_combine(self, other: 'LoraHookGroup'):
        """Return a clone of self extended with copies of other's hooks."""
        merged = self.clone()
        for foreign in other.hooks:
            merged.add(foreign.get_copy())
        return merged

    def set_keyframes_on_hooks(self, hook_kf: 'LoraHookKeyframeGroup'):
        """Assign one cloned keyframe group to every hook in this group."""
        shared_kf = hook_kf.clone()
        for hook in self.hooks:
            hook.lora_keyframe = shared_kf

    @staticmethod
    def combine_all_lora_hooks(lora_hooks_list: list['LoraHookGroup'], require_count=1) -> 'LoraHookGroup':
        """Merge all non-None groups into one combined group.

        Raises when fewer than ``require_count`` groups were provided.
        """
        actual: list[LoraHookGroup] = [group for group in lora_hooks_list if group is not None]
        if len(actual) < require_count:
            raise Exception(f"Need at least {require_count} LoRA Hooks to combine, but only had {len(actual)}.")
        if len(actual) == 1:
            return actual[0]
        final_hook: LoraHookGroup = None
        for group in actual:
            final_hook = group.clone() if final_hook is None else final_hook.clone_and_combine(group)
        return final_hook
|
|
|
|
|
class LoraHookKeyframe:
    """A single point in a LoRA strength schedule.

    Holds the strength to use from ``start_percent`` onward; ``start_t``
    is a placeholder sigma until resolved by initialize_timesteps.
    """

    def __init__(self, strength: float, start_percent=0.0, guarantee_steps=1):
        self.strength = strength
        self.start_percent = float(start_percent)
        # Sentinel sigma; overwritten once timesteps are initialized.
        self.start_t = 999999999.9
        self.guarantee_steps = guarantee_steps

    def clone(self):
        """Return an independent copy, preserving any resolved start_t."""
        duplicate = LoraHookKeyframe(strength=self.strength,
                                     start_percent=self.start_percent,
                                     guarantee_steps=self.guarantee_steps)
        duplicate.start_t = self.start_t
        return duplicate
|
|
|
class LoraHookKeyframeGroup:
    """Ordered schedule of LoraHookKeyframes tracked during sampling.

    Keyframes are kept sorted by ``start_percent``; prepare_current_keyframe
    advances the "current" keyframe as the sampling sigma decreases, and
    ``strength`` reports the current keyframe's strength (1.0 if empty).
    """

    def __init__(self):
        self.keyframes: list[LoraHookKeyframe] = []
        self._current_keyframe: LoraHookKeyframe = None
        self._current_used_steps: int = 0  # steps consumed by current keyframe
        self._current_index: int = 0
        self._curr_t: float = -1  # last sigma seen; guards repeated calls

    def reset(self):
        """Clear scheduling state and restart at the first keyframe."""
        self._current_keyframe = None
        self._current_used_steps = 0
        self._current_index = 0
        self._curr_t = -1
        self._set_first_as_current()

    def add(self, keyframe: LoraHookKeyframe):
        """Insert a keyframe, keeping the list sorted by start_percent."""
        self.keyframes.append(keyframe)
        self.keyframes = get_sorted_list_via_attr(self.keyframes, "start_percent")
        self._set_first_as_current()

    def _set_first_as_current(self):
        # Current keyframe defaults to the earliest one (or None if empty).
        if len(self.keyframes) > 0:
            self._current_keyframe = self.keyframes[0]
        else:
            self._current_keyframe = None

    def has_index(self, index: int) -> bool:
        """True when ``index`` is a valid position in the keyframe list."""
        return index >= 0 and index < len(self.keyframes)

    def is_empty(self) -> bool:
        """True when no keyframes have been added."""
        return len(self.keyframes) == 0

    def clone(self):
        """Return a new group sharing the same keyframe objects (shallow)."""
        cloned = LoraHookKeyframeGroup()
        for keyframe in self.keyframes:
            cloned.keyframes.append(keyframe)
        cloned._set_first_as_current()
        return cloned

    def initialize_timesteps(self, model: BaseModel):
        """Resolve each keyframe's start_percent into a model sigma."""
        for keyframe in self.keyframes:
            keyframe.start_t = model.model_sampling.percent_to_sigma(keyframe.start_percent)

    def prepare_current_keyframe(self, curr_t: float) -> bool:
        """Advance the current keyframe for sigma ``curr_t``.

        Returns True when the current keyframe changed. Calling again with
        the same sigma is a no-op (treated as the same sampling step).
        """
        if self.is_empty():
            return False
        if curr_t == self._curr_t:
            return False
        prev_index = self._current_index
        # Only consider moving on once the current keyframe has used up
        # its guaranteed number of steps.
        if self._current_used_steps >= self._current_keyframe.guarantee_steps:
            if self.has_index(self._current_index+1):
                for i in range(self._current_index+1, len(self.keyframes)):
                    eval_c = self.keyframes[i]
                    # Sigmas decrease during sampling, so a keyframe whose
                    # start_t is >= curr_t has been reached. Keep scanning
                    # for the latest reached keyframe, but stop immediately
                    # once one with guaranteed steps is selected.
                    if eval_c.start_t >= curr_t:
                        self._current_index = i
                        self._current_keyframe = eval_c
                        self._current_used_steps = 0
                        if self._current_keyframe.guarantee_steps > 0:
                            break
                    # List is sorted, so once a keyframe is not yet reached,
                    # none of the later ones are either.
                    else: break
        # Count this sigma as one step used by the current keyframe.
        self._current_used_steps += 1
        self._curr_t = curr_t
        return prev_index != self._current_index

    @property
    def strength(self):
        # Strength of the active keyframe; defaults to full strength.
        if self._current_keyframe is not None:
            return self._current_keyframe.strength
        return 1.0
|
|
|
|
|
class COND_CONST:
    """Dictionary keys and option values used on conditioning entries."""

    # Keys set on a conditioning entry's options dict.
    KEY_LORA_HOOK = "lora_hook"
    KEY_DEFAULT_COND = "default_cond"

    # Valid choices for how the cond area is determined when masking.
    COND_AREA_DEFAULT = "default"
    COND_AREA_MASK_BOUNDS = "mask bounds"
    _LIST_COND_AREA = [COND_AREA_DEFAULT, COND_AREA_MASK_BOUNDS]
|
|
|
|
|
class TimestepsCond:
    """Sampling percent range (start/end) to restrict conditioning to."""

    def __init__(self, start_percent: float, end_percent: float):
        self.start_percent, self.end_percent = start_percent, end_percent
|
|
|
|
|
def conditioning_set_values(conditioning, values=None):
    """Return a copy of ``conditioning`` with ``values`` merged into each
    entry's options dict.

    Args:
        conditioning: list of [cond, options-dict] pairs (ComfyUI format).
        values: mapping of option keys to set on each copied dict;
            defaults to no changes. (Was a mutable ``{}`` default — the
            classic shared-mutable-default pitfall — replaced with a
            ``None`` sentinel; behavior for callers is unchanged.)

    Returns:
        New list of [cond, dict] pairs; cond objects are shared, the
        options dicts are copied so the inputs are never mutated.
    """
    if values is None:
        values = {}
    result = []
    for entry in conditioning:
        # Merge copies the original options, then overlays the new values.
        result.append([entry[0], {**entry[1], **values}])
    return result
|
|
|
def set_lora_hook_for_conditioning(conditioning, lora_hook: LoraHookGroup):
    """Attach ``lora_hook`` to every conditioning entry.

    Returns the input unchanged when ``lora_hook`` is None.
    """
    return (conditioning if lora_hook is None
            else conditioning_set_values(conditioning, {COND_CONST.KEY_LORA_HOOK: lora_hook}))
|
|
|
def set_timesteps_for_conditioning(conditioning, timesteps_cond: TimestepsCond):
    """Limit every conditioning entry to the given timestep percent range.

    Returns the input unchanged when ``timesteps_cond`` is None.
    """
    if timesteps_cond is None:
        return conditioning
    window = {"start_percent": timesteps_cond.start_percent,
              "end_percent": timesteps_cond.end_percent}
    return conditioning_set_values(conditioning, window)
|
|
|
def set_mask_for_conditioning(conditioning, mask: Tensor, set_cond_area: str, strength: float):
    """Apply ``mask`` (with ``strength``) to every conditioning entry.

    Returns the input unchanged when ``mask`` is None. Any set_cond_area
    other than the default requests bounding the cond area to the mask.
    """
    if mask is None:
        return conditioning
    set_area_to_bounds = set_cond_area != COND_CONST.COND_AREA_DEFAULT
    # Give the mask a leading batch dimension if it lacks one.
    if len(mask.shape) < 3:
        mask = mask.unsqueeze(0)
    return conditioning_set_values(conditioning, {"mask": mask,
                                                  "set_area_to_bounds": set_area_to_bounds,
                                                  "mask_strength": strength})
|
|
|
def combine_conditioning(conds: list):
    """Flatten a list of conditioning lists into one combined list."""
    return [entry for cond in conds for entry in cond]
|
|
|
def set_mask_conds(conds: list, strength: float, set_cond_area: str,
                   opt_mask: Tensor=None, opt_lora_hook: LoraHookGroup=None, opt_timesteps: TimestepsCond=None):
    """Apply optional LoRA hook, mask, and timestep range to each cond list.

    Each entry of ``conds`` is a conditioning list; every optional step is a
    no-op when its corresponding argument is None.
    """
    def _prepare(cond):
        # Hook -> mask -> timestep window, same order as the other helpers.
        cond = set_lora_hook_for_conditioning(cond, opt_lora_hook)
        cond = set_mask_for_conditioning(conditioning=cond, mask=opt_mask,
                                         strength=strength, set_cond_area=set_cond_area)
        return set_timesteps_for_conditioning(conditioning=cond, timesteps_cond=opt_timesteps)
    return [_prepare(c) for c in conds]
|
|
|
def set_mask_and_combine_conds(conds: list, new_conds: list, strength: float=1.0, set_cond_area: str="default",
                               opt_mask: Tensor=None, opt_lora_hook: LoraHookGroup=None, opt_timesteps: TimestepsCond=None):
    """Hook/mask/window each new cond list, then merge it with its paired
    original cond list.

    ``conds`` and ``new_conds`` are iterated in lockstep; extras in the
    longer list are ignored (zip semantics).
    """
    combined_conds = []
    for existing, incoming in zip(conds, new_conds):
        incoming = set_lora_hook_for_conditioning(incoming, opt_lora_hook)
        incoming = set_mask_for_conditioning(conditioning=incoming, mask=opt_mask,
                                             set_cond_area=set_cond_area, strength=strength)
        incoming = set_timesteps_for_conditioning(conditioning=incoming, timesteps_cond=opt_timesteps)
        combined_conds.append(combine_conditioning([existing, incoming]))
    return combined_conds
|
|
|
def set_unmasked_and_combine_conds(conds: list, new_conds: list,
                                   opt_lora_hook: LoraHookGroup, opt_timesteps: TimestepsCond=None):
    """Hook each new cond list, flag it with the default-cond key, and merge
    it with its paired original cond list.

    NOTE(review): KEY_DEFAULT_COND presumably marks the entry as the
    fallback/unmasked variant — confirm against the sampler that reads it.
    """
    combined_conds = []
    for existing, incoming in zip(conds, new_conds):
        incoming = set_lora_hook_for_conditioning(incoming, opt_lora_hook)
        incoming = conditioning_set_values(incoming, {COND_CONST.KEY_DEFAULT_COND: True})
        incoming = set_timesteps_for_conditioning(conditioning=incoming, timesteps_cond=opt_timesteps)
        combined_conds.append(combine_conditioning([existing, incoming]))
    return combined_conds
|
|