AlanB committed
Commit 26206b5 · 1 Parent(s): 69e533f

Added sequential_cpu_offload and model_cpu_offload by d1g1t

Files changed (1)
  1. pipeline.py +67 -6
pipeline.py CHANGED
@@ -16,7 +16,7 @@ from diffusers import SchedulerMixin, StableDiffusionPipeline
 from diffusers.models import AutoencoderKL, UNet2DConditionModel
 from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput, StableDiffusionSafetyChecker
 from diffusers.utils import logging
-
+from diffusers.utils import deprecate, is_accelerate_available, is_accelerate_version
 
 try:
     from diffusers.utils import PIL_INTERPOLATION
@@ -281,6 +281,7 @@ def get_weighted_text_embeddings(
         skip_weighting (`bool`, *optional*, defaults to `False`):
             Skip the weighting. When the parsing is skipped, it is forced True.
     """
+    unet_device = torch.device('cpu') if pipe.unet.device == torch.device('meta') else pipe.unet.device
     max_length = (pipe.tokenizer.model_max_length - 2) * max_embeddings_multiples + 2
     if isinstance(prompt, str):
         prompt = [prompt]
@@ -329,7 +330,7 @@ def get_weighted_text_embeddings(
         no_boseos_middle=no_boseos_middle,
         chunk_length=pipe.tokenizer.model_max_length,
     )
-    prompt_tokens = torch.tensor(prompt_tokens, dtype=torch.long, device=pipe.device)
+    prompt_tokens = torch.tensor(prompt_tokens, dtype=torch.long, device=unet_device)
     if uncond_prompt is not None:
         uncond_tokens, uncond_weights = pad_tokens_and_weights(
             uncond_tokens,
@@ -340,7 +341,7 @@ def get_weighted_text_embeddings(
             no_boseos_middle=no_boseos_middle,
             chunk_length=pipe.tokenizer.model_max_length,
         )
-        uncond_tokens = torch.tensor(uncond_tokens, dtype=torch.long, device=pipe.device)
+        uncond_tokens = torch.tensor(uncond_tokens, dtype=torch.long, device=unet_device)
 
     # get the embeddings
     text_embeddings = get_unweighted_text_embeddings(
@@ -349,7 +350,8 @@ def get_weighted_text_embeddings(
         pipe.tokenizer.model_max_length,
         no_boseos_middle=no_boseos_middle,
     )
-    prompt_weights = torch.tensor(prompt_weights, dtype=text_embeddings.dtype, device=pipe.device)
+    text_embeddings = text_embeddings.to(device=unet_device)
+    prompt_weights = torch.tensor(prompt_weights, dtype=text_embeddings.dtype, device=unet_device)
     if uncond_prompt is not None:
         uncond_embeddings = get_unweighted_text_embeddings(
             pipe,
@@ -357,7 +359,8 @@ def get_weighted_text_embeddings(
             pipe.tokenizer.model_max_length,
             no_boseos_middle=no_boseos_middle,
         )
-        uncond_weights = torch.tensor(uncond_weights, dtype=uncond_embeddings.dtype, device=pipe.device)
+        uncond_embeddings = uncond_embeddings.to(device=unet_device)
+        uncond_weights = torch.tensor(uncond_weights, dtype=uncond_embeddings.dtype, device=unet_device)
 
     # assign weights to the prompts and normalize in the sense of mean
     # TODO: should we normalize by chunk or in a whole (current implementation)?
@@ -481,6 +484,59 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
         if not hasattr(self, "vae_scale_factor"):
             setattr(self, "vae_scale_factor", 2 ** (len(self.vae.config.block_out_channels) - 1))
 
+    def enable_sequential_cpu_offload(self, gpu_id=0):
+        r"""
+        Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
+        text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a
+        `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called.
+        Note that offloading happens on a submodule basis. Memory savings are higher than with
+        `enable_model_cpu_offload`, but performance is lower.
+        """
+        if is_accelerate_available():
+            from accelerate import cpu_offload
+        else:
+            raise ImportError("Please install accelerate via `pip install accelerate`")
+
+        device = torch.device(f"cuda:{gpu_id}")
+
+        if self.device.type != "cpu":
+            self.to("cpu", silence_dtype_warnings=True)
+            torch.cuda.empty_cache()  # otherwise we don't see the memory savings (but they probably exist)
+
+        for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
+            cpu_offload(cpu_offloaded_model, device)
+
+        if self.safety_checker is not None:
+            cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True)
+
+    def enable_model_cpu_offload(self, gpu_id=0):
+        r"""
+        Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
+        to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
+        method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with
+        `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
+        """
+        if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
+            from accelerate import cpu_offload_with_hook
+        else:
+            raise ImportError("`enable_model_offload` requires `accelerate v0.17.0` or higher.")
+
+        device = torch.device(f"cuda:{gpu_id}")
+
+        if self.device.type != "cpu":
+            self.to("cpu", silence_dtype_warnings=True)
+            torch.cuda.empty_cache()  # otherwise we don't see the memory savings (but they probably exist)
+
+        hook = None
+        for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]:
+            _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)
+
+        if self.safety_checker is not None:
+            _, hook = cpu_offload_with_hook(self.safety_checker, device, prev_module_hook=hook)
+
+        # We'll offload the last model manually.
+        self.final_offload_hook = hook
+
     @property
     def _execution_device(self):
         r"""
@@ -488,7 +544,8 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
         `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
         hooks.
         """
-        if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"):
+        #if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"):
+        if not hasattr(self.unet, "_hf_hook"):
             return self.device
         for module in self.unet.modules():
             if (
@@ -858,6 +915,10 @@ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
         if output_type == "pil":
            image = self.numpy_to_pil(image)
 
+        # 12. Offload last model to CPU
+        if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
+            self.final_offload_hook.offload()
+
         if not return_dict:
             return image, has_nsfw_concept
 
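
For reference, a minimal usage sketch of the two new offload modes added by this commit. This is not part of the commit itself: the base checkpoint id and the `custom_pipeline` source below are placeholders (adjust them to wherever this pipeline.py is hosted), only one offload mode should be enabled per pipeline, and `enable_model_cpu_offload` needs accelerate >= 0.17.0.

# Usage sketch -- placeholder model id and custom_pipeline source, adapt to this repo.
import torch
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",        # placeholder base checkpoint
    custom_pipeline="lpw_stable_diffusion",  # or the repo/path that contains this pipeline.py
    torch_dtype=torch.float16,
)

# Whole-model offload: moves one model at a time to the GPU, small speed penalty.
pipe.enable_model_cpu_offload()

# Alternative (pick one): submodule-level offload for maximum memory savings, slower.
# pipe.enable_sequential_cpu_offload()

# Do not call pipe.to("cuda") when offload is enabled; the hooks manage device placement.
image = pipe("a fantasy landscape, highly detailed", num_inference_steps=30).images[0]
image.save("fantasy_landscape.png")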