manbeast3b committed
Commit b84b09a · verified · Parent: 7f2e11c

Update src/caching.py

Files changed (1): src/caching.py (+149, -0)
src/caching.py CHANGED
@@ -174,6 +174,155 @@ def apply_cache_on_transformer(
 def apply_cache_on_pipe(pipe: DiffusionPipeline, *, shallow_patch: bool = False, **kwargs):
     original_call = pipe.__class__.__call__

+    if not getattr(original_call, "_is_cached", False):
+        @functools.wraps(original_call)
+        def new_call(self, *args, **kwargs):
+            with cache_context(create_cache_context()):
+                return original_call(self, *args, **kwargs)
+
+        pipe.__class__.__call__ = new_call
+        new_call._is_cached = True
+
+    if not shallow_patch:
+        apply_cache_on_transformer(pipe.transformer, **kwargs)
+
+    pipe._is_cached = True
+    return pipe
+
+@dataclasses.dataclass
+class CacheContext:
+    buffers: Dict[str, torch.Tensor] = dataclasses.field(default_factory=dict)
+    incremental_name_counters: DefaultDict[str, int] = dataclasses.field(default_factory=lambda: defaultdict(int))
+
+    def get_buffer(self, name):
+        return self.buffers.get(name)
+
+    def set_buffer(self, name, buffer):
+        self.buffers[name] = buffer
+
+    def clear_buffers(self):
+        self.buffers.clear()
+
+_current_cache_context = None
+
+def create_cache_context():
+    return CacheContext()
+
+def get_current_cache_context():
+    return _current_cache_context
+
+def set_current_cache_context(cache_context=None):
+    global _current_cache_context
+    _current_cache_context = cache_context
+
+@contextlib.contextmanager
+def cache_context(cache_context):
+    global _current_cache_context
+    old_cache_context = _current_cache_context
+    _current_cache_context = cache_context
+    try:
+        yield
+    finally:
+        _current_cache_context = old_cache_context
+
+def are_two_tensors_similar(t1, t2, *, threshold=0.85):
+    mean_diff = (t1 - t2).abs().mean()
+    mean_t1 = t1.abs().mean()
+    diff = mean_diff / mean_t1
+    return diff.item() < threshold
+
+class CachedTransformerBlocks(torch.nn.Module):
+    def __init__(
+        self,
+        transformer_blocks,
+        single_transformer_blocks=None,
+        *,
+        transformer=None,
+        residual_diff_threshold=0.05,
+    ):
+        super().__init__()
+        self.transformer = transformer
+        self.transformer_blocks = transformer_blocks
+        self.single_transformer_blocks = single_transformer_blocks
+        self.residual_diff_threshold = residual_diff_threshold
+
+    def forward(self, encoder_hidden_states, hidden_states, *args, **kwargs):
+        # Important: For Flux, the order is encoder_hidden_states, hidden_states
+        original_encoder_states = encoder_hidden_states
+
+        # Process first block
+        encoder_hidden_states, hidden_states = self.transformer_blocks[0](
+            encoder_hidden_states, hidden_states, *args, **kwargs
+        )
+
+        # Calculate residual for encoder states
+        first_residual = encoder_hidden_states - original_encoder_states
+
+        cache_context = get_current_cache_context()
+        prev_residual = cache_context.get_buffer("first_residual")
+        can_use_cache = prev_residual is not None and are_two_tensors_similar(
+            prev_residual,
+            first_residual,
+            threshold=self.residual_diff_threshold
+        )
+
+        if can_use_cache:
+            residual = cache_context.get_buffer("residual")
+            encoder_hidden_states = encoder_hidden_states + residual
+        else:
+            cache_context.set_buffer("first_residual", first_residual)
+
+            # Process remaining blocks
+            for block in self.transformer_blocks[1:]:
+                encoder_hidden_states, hidden_states = block(
+                    encoder_hidden_states, hidden_states, *args, **kwargs
+                )
+
+            cache_context.set_buffer("residual", encoder_hidden_states - original_encoder_states)
+
+        return encoder_hidden_states, hidden_states
+
+def apply_cache_on_transformer(
+    transformer: FluxTransformer2DModel,
+    *,
+    residual_diff_threshold=0.05,
+):
+    cached_transformer_blocks = torch.nn.ModuleList([
+        CachedTransformerBlocks(
+            transformer.transformer_blocks,
+            transformer.single_transformer_blocks if hasattr(transformer, 'single_transformer_blocks') else None,
+            transformer=transformer,
+            residual_diff_threshold=residual_diff_threshold,
+        )
+    ])
+    dummy_single_transformer_blocks = torch.nn.ModuleList()
+
+    original_forward = transformer.forward
+
+    @functools.wraps(original_forward)
+    def new_forward(self, *args, **kwargs):
+        with unittest.mock.patch.object(
+            self,
+            "transformer_blocks",
+            cached_transformer_blocks,
+        ), unittest.mock.patch.object(
+            self,
+            "single_transformer_blocks",
+            dummy_single_transformer_blocks,
+        ):
+            return original_forward(*args, **kwargs)
+
+    transformer.forward = new_forward.__get__(transformer)
+    return transformer
+
+def apply_cache_on_pipe(
+    pipe: DiffusionPipeline,
+    *,
+    shallow_patch: bool = False,
+    **kwargs,
+):
+    original_call = pipe.__class__.__call__
+
     if not getattr(original_call, "_is_cached", False):
         @functools.wraps(original_call)
         def new_call(self, *args, **kwargs):
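
A minimal usage sketch of the patched entry point, assuming the top-level imports in src/caching.py are already in place and the file is importable as src.caching; the FluxPipeline checkpoint id, prompt, and generation arguments below are placeholders, not taken from this commit.

import torch
from diffusers import FluxPipeline

from src.caching import apply_cache_on_pipe  # import path assumed from this repo's layout

# Hypothetical checkpoint; any Flux-style DiffusionPipeline exposing .transformer should work.
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to("cuda")

# Patch __call__ so each generation runs inside a fresh cache context, and wrap
# the transformer blocks with the residual-caching module added in this commit.
apply_cache_on_pipe(pipe, residual_diff_threshold=0.05)

image = pipe("a photo of an astronaut riding a horse", num_inference_steps=28).images[0]
image.save("astronaut.png")

With residual_diff_threshold=0.05, a denoising step whose first-block encoder residual differs from the previously cached one by less than 5% (relative mean absolute difference, as computed in are_two_tensors_similar) reuses the cached residual and skips the remaining transformer blocks; larger differences recompute and refresh the cache.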