from torch import Tensor
from comfy.model_base import BaseModel
from .utils_motion import get_sorted_list_via_attr

class LoraHookMode:
MIN_VRAM = "min_vram"
MAX_SPEED = "max_speed"
#MIN_VRAM_LOWVRAM = "min_vram_lowvram"
    #MAX_SPEED_LOWVRAM = "max_speed_lowvram"

# Acts simply as a way to track unique LoraHooks
class HookRef:
    pass

class LoraHook:
def __init__(self, lora_name: str):
self.lora_name = lora_name
self.lora_keyframe = LoraHookKeyframeGroup()
self.hook_ref = HookRef()
def initialize_timesteps(self, model: BaseModel):
self.lora_keyframe.initialize_timesteps(model)
def reset(self):
self.lora_keyframe.reset()
def get_copy(self):
'''
Copies LoraHook, but maintains same HookRef
'''
c = LoraHook(lora_name=self.lora_name)
c.lora_keyframe = self.lora_keyframe
c.hook_ref = self.hook_ref # same instance that acts as ref
return c
@property
def strength(self):
return self.lora_keyframe.strength
def __eq__(self, other: 'LoraHook'):
return self.__class__ == other.__class__ and self.hook_ref == other.hook_ref
def __hash__(self):
return hash(self.hook_ref)
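
# Usage sketch (illustrative; the file name below is hypothetical): copies made
# via get_copy() share the same HookRef instance, so they compare equal and
# hash identically even though the copies are distinct objects.
#
#   hook = LoraHook(lora_name="example_lora.safetensors")
#   copy = hook.get_copy()
#   assert hook == copy and hash(hook) == hash(copy)
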
class LoraHookGroup:
'''
Stores LoRA hooks to apply for conditioning
'''
def __init__(self):
self.hooks: list[LoraHook] = []
def names(self):
names = []
for hook in self.hooks:
names.append(hook.lora_name)
return ",".join(names)
def add(self, hook: LoraHook):
if hook not in self.hooks:
self.hooks.append(hook)
def is_empty(self):
return len(self.hooks) == 0
def contains(self, lora_hook: LoraHook):
return lora_hook in self.hooks
def clone(self):
cloned = LoraHookGroup()
for hook in self.hooks:
cloned.add(hook.get_copy())
return cloned
def clone_and_combine(self, other: 'LoraHookGroup'):
cloned = self.clone()
for hook in other.hooks:
cloned.add(hook.get_copy())
return cloned
def set_keyframes_on_hooks(self, hook_kf: 'LoraHookKeyframeGroup'):
hook_kf = hook_kf.clone()
for hook in self.hooks:
hook.lora_keyframe = hook_kf
@staticmethod
def combine_all_lora_hooks(lora_hooks_list: list['LoraHookGroup'], require_count=1) -> 'LoraHookGroup':
actual: list[LoraHookGroup] = []
for group in lora_hooks_list:
if group is not None:
actual.append(group)
if len(actual) < require_count:
raise Exception(f"Need at least {require_count} LoRA Hooks to combine, but only had {len(actual)}.")
        # if only 1 group, return it as-is without any cloning
if len(actual) == 1:
return actual[0]
final_hook: LoraHookGroup = None
for hook in actual:
if final_hook is None:
final_hook = hook.clone()
else:
final_hook = final_hook.clone_and_combine(hook)
return final_hook
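
# Usage sketch (illustrative; group/lora names are hypothetical): combining
# clones the hooks but preserves their HookRefs, and add() deduplicates any
# hooks that share a ref.
#
#   group_a = LoraHookGroup(); group_a.add(LoraHook(lora_name="styleA"))
#   group_b = LoraHookGroup(); group_b.add(LoraHook(lora_name="styleB"))
#   combined = LoraHookGroup.combine_all_lora_hooks([group_a, None, group_b])
#   combined.names()  # -> "styleA,styleB"
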
class LoraHookKeyframe:
def __init__(self, strength: float, start_percent=0.0, guarantee_steps=1):
self.strength = strength
# scheduling
self.start_percent = float(start_percent)
self.start_t = 999999999.9
self.guarantee_steps = guarantee_steps
def clone(self):
c = LoraHookKeyframe(strength=self.strength,
start_percent=self.start_percent, guarantee_steps=self.guarantee_steps)
c.start_t = self.start_t
        return c

class LoraHookKeyframeGroup:
def __init__(self):
self.keyframes: list[LoraHookKeyframe] = []
self._current_keyframe: LoraHookKeyframe = None
self._current_used_steps: int = 0
self._current_index: int = 0
self._curr_t: float = -1
def reset(self):
self._current_keyframe = None
self._current_used_steps = 0
self._current_index = 0
self._curr_t = -1
self._set_first_as_current()
def add(self, keyframe: LoraHookKeyframe):
# add to end of list, then sort
self.keyframes.append(keyframe)
self.keyframes = get_sorted_list_via_attr(self.keyframes, "start_percent")
self._set_first_as_current()
def _set_first_as_current(self):
if len(self.keyframes) > 0:
self._current_keyframe = self.keyframes[0]
else:
self._current_keyframe = None
    def has_index(self, index: int) -> bool:
return index >= 0 and index < len(self.keyframes)
def is_empty(self) -> bool:
return len(self.keyframes) == 0
def clone(self):
cloned = LoraHookKeyframeGroup()
for keyframe in self.keyframes:
cloned.keyframes.append(keyframe)
cloned._set_first_as_current()
return cloned
def initialize_timesteps(self, model: BaseModel):
for keyframe in self.keyframes:
keyframe.start_t = model.model_sampling.percent_to_sigma(keyframe.start_percent)
def prepare_current_keyframe(self, curr_t: float) -> bool:
if self.is_empty():
return False
if curr_t == self._curr_t:
return False
prev_index = self._current_index
        # if guaranteed steps are met, look for the next keyframe in case we need to switch
        if self._current_used_steps >= self._current_keyframe.guarantee_steps:
            # if there is a next index, loop through and see if we need to switch
if self.has_index(self._current_index+1):
for i in range(self._current_index+1, len(self.keyframes)):
eval_c = self.keyframes[i]
# check if start_t is greater or equal to curr_t
# NOTE: t is in terms of sigmas, not percent, so bigger number = earlier step in sampling
if eval_c.start_t >= curr_t:
self._current_index = i
self._current_keyframe = eval_c
self._current_used_steps = 0
# if guarantee_steps greater than zero, stop searching for other keyframes
if self._current_keyframe.guarantee_steps > 0:
break
                    # if eval_c is outside the percent range, stop looking further
                    else:
                        break
        # update how many steps the current keyframe has been used
self._current_used_steps += 1
# update current timestep this was performed on
self._curr_t = curr_t
# return True if keyframe changed, False if no change
return prev_index != self._current_index
# properties shadow those of LoraHookKeyframe
@property
def strength(self):
if self._current_keyframe is not None:
return self._current_keyframe.strength
return 1.0
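
# Scheduling sketch (illustrative; `model` and `some_sigma` are assumed to come
# from the caller): keyframes stay sorted by start_percent,
# initialize_timesteps() converts each start_percent to a sigma, and
# prepare_current_keyframe() advances through keyframes as sampling proceeds
# (sigmas shrink as steps advance).
#
#   kf_group = LoraHookKeyframeGroup()
#   kf_group.add(LoraHookKeyframe(strength=1.0, start_percent=0.0))
#   kf_group.add(LoraHookKeyframe(strength=0.25, start_percent=0.5))
#   kf_group.initialize_timesteps(model)
#   changed = kf_group.prepare_current_keyframe(curr_t=some_sigma)
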
class COND_CONST:
KEY_LORA_HOOK = "lora_hook"
KEY_DEFAULT_COND = "default_cond"
COND_AREA_DEFAULT = "default"
COND_AREA_MASK_BOUNDS = "mask bounds"
    _LIST_COND_AREA = [COND_AREA_DEFAULT, COND_AREA_MASK_BOUNDS]

class TimestepsCond:
def __init__(self, start_percent: float, end_percent: float):
self.start_percent = start_percent
        self.end_percent = end_percent

def conditioning_set_values(conditioning, values={}):
c = []
for t in conditioning:
n = [t[0], t[1].copy()]
for k in values:
n[1][k] = values[k]
c.append(n)
return c
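
# Note: this assumes ComfyUI's conditioning format, a list of
# [cond_tensor, options_dict] pairs; only each options dict is copied
# (shallowly), so the underlying tensors are shared with the input.
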
def set_lora_hook_for_conditioning(conditioning, lora_hook: LoraHookGroup):
if lora_hook is None:
return conditioning
    return conditioning_set_values(conditioning, {COND_CONST.KEY_LORA_HOOK: lora_hook})

def set_timesteps_for_conditioning(conditioning, timesteps_cond: TimestepsCond):
if timesteps_cond is None:
return conditioning
return conditioning_set_values(conditioning, {"start_percent": timesteps_cond.start_percent,
"end_percent": timesteps_cond.end_percent})
def set_mask_for_conditioning(conditioning, mask: Tensor, set_cond_area: str, strength: float):
if mask is None:
return conditioning
set_area_to_bounds = False
if set_cond_area != COND_CONST.COND_AREA_DEFAULT:
set_area_to_bounds = True
if len(mask.shape) < 3:
mask = mask.unsqueeze(0)
return conditioning_set_values(conditioning, {"mask": mask,
"set_area_to_bounds": set_area_to_bounds,
"mask_strength": strength})
def combine_conditioning(conds: list):
combined_conds = []
for cond in conds:
combined_conds.extend(cond)
    return combined_conds

def set_mask_conds(conds: list, strength: float, set_cond_area: str,
opt_mask: Tensor=None, opt_lora_hook: LoraHookGroup=None, opt_timesteps: TimestepsCond=None):
masked_conds = []
for c in conds:
# first, apply lora_hook to conditioning, if provided
c = set_lora_hook_for_conditioning(c, opt_lora_hook)
# next, apply mask to conditioning
c = set_mask_for_conditioning(conditioning=c, mask=opt_mask, strength=strength, set_cond_area=set_cond_area)
# apply timesteps, if present
c = set_timesteps_for_conditioning(conditioning=c, timesteps_cond=opt_timesteps)
            # finally, store the fully prepared conditioning
masked_conds.append(c)
    return masked_conds

def set_mask_and_combine_conds(conds: list, new_conds: list, strength: float=1.0, set_cond_area: str="default",
opt_mask: Tensor=None, opt_lora_hook: LoraHookGroup=None, opt_timesteps: TimestepsCond=None):
combined_conds = []
for c, masked_c in zip(conds, new_conds):
# first, apply lora_hook to new conditioning, if provided
masked_c = set_lora_hook_for_conditioning(masked_c, opt_lora_hook)
# next, apply mask to new conditioning, if provided
masked_c = set_mask_for_conditioning(conditioning=masked_c, mask=opt_mask, set_cond_area=set_cond_area, strength=strength)
# apply timesteps, if present
masked_c = set_timesteps_for_conditioning(conditioning=masked_c, timesteps_cond=opt_timesteps)
# finally, combine with existing conditioning and store
combined_conds.append(combine_conditioning([c, masked_c]))
    return combined_conds

def set_unmasked_and_combine_conds(conds: list, new_conds: list,
opt_lora_hook: LoraHookGroup, opt_timesteps: TimestepsCond=None):
combined_conds = []
for c, new_c in zip(conds, new_conds):
# first, apply lora_hook to new conditioning, if provided
new_c = set_lora_hook_for_conditioning(new_c, opt_lora_hook)
# next, add default_cond key to cond so that during sampling, it can be identified
new_c = conditioning_set_values(new_c, {COND_CONST.KEY_DEFAULT_COND: True})
# apply timesteps, if present
new_c = set_timesteps_for_conditioning(conditioning=new_c, timesteps_cond=opt_timesteps)
# finally, combine with existing conditioning and store
combined_conds.append(combine_conditioning([c, new_c]))
return combined_conds
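
# End-to-end sketch (illustrative; every input below is hypothetical): attach a
# hooked, masked variant of a cond to its original, the way the combine helpers
# above are meant to be chained together.
#
#   combined = set_mask_and_combine_conds(
#       conds=[positive_cond], new_conds=[extra_cond],
#       strength=0.8, set_cond_area=COND_CONST.COND_AREA_DEFAULT,
#       opt_mask=mask, opt_lora_hook=hook_group,
#       opt_timesteps=TimestepsCond(start_percent=0.0, end_percent=0.5))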