import contextlib
import unittest.mock

import torch

from comfy import model_management

from . import first_block_cache


class ApplyFBCacheOnModel:
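    """Applies First Block Cache (FBCache) to a ComfyUI model.

    residual_diff_threshold controls how aggressively cached results are reused,
    start/end restrict the effect to a window of the sampling schedule, and
    max_consecutive_cache_hits limits how many cached results may be used in a row.
    """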
    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "model": ("MODEL", ),
                "object_to_patch": (
                    "STRING",
                    {
                        "default": "diffusion_model",
                    },
                ),
                "residual_diff_threshold": (
                    "FLOAT",
                    {
                        "default": 0.0,
                        "min": 0.0,
                        "max": 1.0,
                        "step": 0.001,
                        "tooltip": "Controls the tolerance for caching with lower values being more strict. Setting this to 0 disables the FBCache effect.",
                    },
                ),
                "start": (
                    "FLOAT",
                    {
                        "default": 0.0,
                        "step": 0.01,
                        "max": 1.0,
                        "min": 0.0,
                        "tooltip": "Start time as a percentage of sampling where the FBCache effect can apply. Example: 0.0 would signify 0% (the beginning of sampling), 0.5 would signify 50%.",
                    },
                ),
                "end": (
                    "FLOAT",
                    {
                        "default": 1.0,
                        "step": 0.01,
                        "max": 1.0,
                        "min": 0.0,
                        "tooltip": "End time as a percentage of sampling where the FBCache effect can apply. Example: 1.0 would signify 100% (the end of sampling), 0.5 would signify 50%.",
                    },
                ),
                "max_consecutive_cache_hits": (
                    "INT",
                    {
                        "default": -1,
                        "min": -1,
                        "tooltip": "Allows limiting how many cached results can be used in a row. For example, setting this to 1 will mean there will be at least one full model call after each cached result. Set to 0 to disable the FBCache effect, or -1 to allow unlimited consecutive cache hits.",
                    },
                ),
            }
        }

    RETURN_TYPES = ("MODEL", )
    FUNCTION = "patch"

    CATEGORY = "wavespeed"
    def patch(
        self,
        model,
        object_to_patch,
        residual_diff_threshold,
        max_consecutive_cache_hits=-1,
        start=0.0,
        end=1.0,
    ):
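        # A threshold of 0 or max_consecutive_cache_hits of 0 disables the FBCache
        # effect entirely, so the model is returned unpatched in that case.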
        if residual_diff_threshold <= 0.0 or max_consecutive_cache_hits == 0:
            return (model, )
        first_block_cache.patch_get_output_data()

        using_validation = max_consecutive_cache_hits >= 0 or start > 0 or end < 1
        if using_validation:
            model_sampling = model.get_model_object("model_sampling")
            start_sigma, end_sigma = (float(
                model_sampling.percent_to_sigma(pct)) for pct in (start, end))
            del model_sampling
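            # Note: percent_to_sigma maps sampling progress to a noise level, and
            # sigmas decrease as sampling progresses, so "start" corresponds to the
            # larger sigma and "end" to the smaller one in the range check below.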
            @torch.compiler.disable()
            def validate_use_cache(use_cached):
                nonlocal consecutive_cache_hits
                use_cached = use_cached and end_sigma <= current_timestep <= start_sigma
                use_cached = use_cached and (
                    max_consecutive_cache_hits < 0
                    or consecutive_cache_hits < max_consecutive_cache_hits)
                consecutive_cache_hits = consecutive_cache_hits + 1 if use_cached else 0
                return use_cached
        else:
            validate_use_cache = None

        prev_timestep = None
        prev_input_state = None
        current_timestep = None
        consecutive_cache_hits = 0
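        # The three helpers below track the previous call's timestep and input
        # shape/dtype/device so stale cache state can be detected and discarded.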
        def reset_cache_state():
            # Resets the cache state and hits/time tracking variables.
            nonlocal prev_input_state, prev_timestep, consecutive_cache_hits
            prev_input_state = prev_timestep = None
            consecutive_cache_hits = 0
            first_block_cache.set_current_cache_context(
                first_block_cache.create_cache_context())
        def ensure_cache_state(model_input: torch.Tensor, timestep: float):
            # Validates the current cache state and hits/time tracking variables
            # and triggers a reset if necessary. Also updates current_timestep and
            # maintains the cache context sequence number.
            nonlocal current_timestep
            input_state = (model_input.shape, model_input.dtype, model_input.device)
            cache_context = first_block_cache.get_current_cache_context()
            # We reset when:
            need_reset = (
                # The previous timestep or input state is not set
                prev_timestep is None or
                prev_input_state is None or
                # Or dtype/device have changed
                prev_input_state[1:] != input_state[1:] or
                # Or the input shape after the batch dimension has changed
                prev_input_state[0][1:] != input_state[0][1:] or
                # Or there is no cache context (in this case the reset just creates a context)
                cache_context is None or
                # Or the current timestep is higher than the previous one
                timestep > prev_timestep
            )
            if need_reset:
                reset_cache_state()
            elif timestep == prev_timestep:
                # When the current timestep is the same as the previous one, we assume
                # ComfyUI has split the model evaluation into multiple chunks, so we
                # increment the sequence number.
                # Note: No need to check whether cache_context is None in these branches,
                # as need_reset would be True if it were.
                cache_context.sequence_num += 1
            elif timestep < prev_timestep:
                # When the timestep is less than the previous one, we can reset the
                # context sequence number.
                cache_context.sequence_num = 0
            current_timestep = timestep
        def update_cache_state(model_input: torch.Tensor, timestep: float):
            # Updates the previous timestep and input state validation variables.
            nonlocal prev_timestep, prev_input_state
            prev_timestep = timestep
            prev_input_state = (model_input.shape, model_input.dtype, model_input.device)
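        # Two patching strategies follow: UNetModel and Flux get a patched forward
        # function from first_block_cache, while other transformer models have their
        # block lists temporarily swapped for CachedTransformerBlocks on each call.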
        model = model.clone()
        diffusion_model = model.get_model_object(object_to_patch)
        if diffusion_model.__class__.__name__ in ("UNetModel", "Flux"):
            if diffusion_model.__class__.__name__ == "UNetModel":
                create_patch_function = first_block_cache.create_patch_unet_model__forward
            elif diffusion_model.__class__.__name__ == "Flux":
                create_patch_function = first_block_cache.create_patch_flux_forward_orig
            else:
                raise ValueError(
                    f"Unsupported model {diffusion_model.__class__.__name__}")

            patch_forward = create_patch_function(
                diffusion_model,
                residual_diff_threshold=residual_diff_threshold,
                validate_can_use_cache_function=validate_use_cache,
            )

            def model_unet_function_wrapper(model_function, kwargs):
                try:
                    input = kwargs["input"]
                    timestep = kwargs["timestep"]
                    c = kwargs["c"]
                    t = timestep[0].item()

                    ensure_cache_state(input, t)

                    with patch_forward():
                        result = model_function(input, timestep, **c)
                        update_cache_state(input, t)
                        return result
                except Exception as exc:
                    reset_cache_state()
                    raise exc from None
        else:
            is_non_native_ltxv = False
            if diffusion_model.__class__.__name__ == "LTXVTransformer3D":
                is_non_native_ltxv = True
                diffusion_model = diffusion_model.transformer

            double_blocks_name = None
            single_blocks_name = None
            if hasattr(diffusion_model, "transformer_blocks"):
                double_blocks_name = "transformer_blocks"
            elif hasattr(diffusion_model, "double_blocks"):
                double_blocks_name = "double_blocks"
            elif hasattr(diffusion_model, "joint_blocks"):
                double_blocks_name = "joint_blocks"
            else:
                raise ValueError(
                    f"No double blocks found for {diffusion_model.__class__.__name__}"
                )

            if hasattr(diffusion_model, "single_blocks"):
                single_blocks_name = "single_blocks"

            if is_non_native_ltxv:
                original_create_skip_layer_mask = getattr(
                    diffusion_model, "create_skip_layer_mask", None)
                if original_create_skip_layer_mask is not None:
                    # original_double_blocks = getattr(diffusion_model,
                    #                                  double_blocks_name)

                    def new_create_skip_layer_mask(self, *args, **kwargs):
                        # with unittest.mock.patch.object(self, double_blocks_name,
                        #                                 original_double_blocks):
                        #     return original_create_skip_layer_mask(*args, **kwargs)
                        # return original_create_skip_layer_mask(*args, **kwargs)
                        raise RuntimeError(
                            "STG is not supported with FBCache yet")

                    diffusion_model.create_skip_layer_mask = new_create_skip_layer_mask.__get__(
                        diffusion_model)

            cached_transformer_blocks = torch.nn.ModuleList([
                first_block_cache.CachedTransformerBlocks(
                    None if double_blocks_name is None else getattr(
                        diffusion_model, double_blocks_name),
                    None if single_blocks_name is None else getattr(
                        diffusion_model, single_blocks_name),
                    residual_diff_threshold=residual_diff_threshold,
                    validate_can_use_cache_function=validate_use_cache,
                    cat_hidden_states_first=diffusion_model.__class__.__name__ == "HunyuanVideo",
                    return_hidden_states_only=diffusion_model.__class__.__name__ == "LTXVModel"
                    or is_non_native_ltxv,
                    clone_original_hidden_states=diffusion_model.__class__.__name__ == "LTXVModel",
                    return_hidden_states_first=diffusion_model.__class__.__name__
                    != "OpenAISignatureMMDITWrapper",
                    accept_hidden_states_first=diffusion_model.__class__.__name__
                    != "OpenAISignatureMMDITWrapper",
                )
            ])
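            # The cached wrapper above receives the single blocks directly, so the
            # model's own single block list is swapped for an empty placeholder during
            # the call (see the wrapper below), presumably to avoid running them twice.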
            dummy_single_transformer_blocks = torch.nn.ModuleList()

            def model_unet_function_wrapper(model_function, kwargs):
                try:
                    input = kwargs["input"]
                    timestep = kwargs["timestep"]
                    c = kwargs["c"]
                    t = timestep[0].item()

                    ensure_cache_state(input, t)

                    with unittest.mock.patch.object(
                            diffusion_model,
                            double_blocks_name,
                            cached_transformer_blocks,
                    ), unittest.mock.patch.object(
                            diffusion_model,
                            single_blocks_name,
                            dummy_single_transformer_blocks,
                    ) if single_blocks_name is not None else contextlib.nullcontext():
                        result = model_function(input, timestep, **c)
                        update_cache_state(input, t)
                        return result
                except Exception as exc:
                    reset_cache_state()
                    raise exc from None
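        # Register the wrapper so every sampling call to the model goes through the
        # cache state checks above.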
        model.set_model_unet_function_wrapper(model_unet_function_wrapper)
        return (model, )