AlanB committed
Commit 181c850 · 1 Parent(s): 151612e

Big changes with SchedulerMixin and the pipeline. Hope it doesn't break anything.

Files changed (1)
  1. pipeline.py +261 -537
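The diff below drops the hand-rolled scheduler union and helper re-implementations in favour of subclassing StableDiffusionPipeline and accepting any SchedulerMixin. As a hedged sketch only (the checkpoint id and the local directory holding pipeline.py are placeholders, not part of this commit), the refactored file might be used as a diffusers custom pipeline roughly like this:

import torch
from diffusers import DiffusionPipeline

# Load any Stable Diffusion checkpoint with this file as the custom pipeline.
pipe = DiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",        # placeholder checkpoint id
    custom_pipeline="path/to/dir_with_pipeline_py",  # directory containing this pipeline.py
    torch_dtype=torch.float16,
).to("cuda")

# Prompts longer than 77 tokens and (word:weight) emphasis are parsed by the pipeline itself.
image = pipe("a (white:1.2) cat sitting on a windowsill, golden hour, highly detailed").images[0]
image.save("cat.png")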
pipeline.py CHANGED
@@ -9,38 +9,13 @@ import sys
  from tqdm.auto import tqdm

  import PIL
- from diffusers.configuration_utils import FrozenDict
  from diffusers.models import AutoencoderKL, UNet2DConditionModel
- from diffusers.pipeline_utils import DiffusionPipeline
- from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
- from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
- from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler, DPMSolverMultistepScheduler, EulerDiscreteScheduler, EulerAncestralDiscreteScheduler
- from diffusers.utils import deprecate, is_accelerate_available, logging
-
- # TODO: remove and import from diffusers.utils when the new version of diffusers is released
- from packaging import version
  from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer


- if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0"):
-     PIL_INTERPOLATION = {
-         "linear": PIL.Image.Resampling.BILINEAR,
-         "bilinear": PIL.Image.Resampling.BILINEAR,
-         "bicubic": PIL.Image.Resampling.BICUBIC,
-         "lanczos": PIL.Image.Resampling.LANCZOS,
-         "nearest": PIL.Image.Resampling.NEAREST,
-     }
- else:
-     PIL_INTERPOLATION = {
-         "linear": PIL.Image.LINEAR,
-         "bilinear": PIL.Image.BILINEAR,
-         "bicubic": PIL.Image.BICUBIC,
-         "lanczos": PIL.Image.LANCZOS,
-         "nearest": PIL.Image.NEAREST,
-     }
- # ------------------------------------------------------------------------------
-
-
  logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

  re_attention = re.compile(
@@ -149,7 +124,7 @@ def parse_prompt_attention(text):
      return res


- def get_prompts_with_weights(pipe: DiffusionPipeline, prompt: List[str], max_length: int):
      r"""
      Tokenize a list of prompts and return its tokens with weights of each token.

@@ -210,7 +185,7 @@ def pad_tokens_and_weights(tokens, weights, max_length, bos, eos, no_boseos_midd


  def get_unweighted_text_embeddings(
-     pipe: DiffusionPipeline,
      text_input: torch.Tensor,
      chunk_length: int,
      no_boseos_middle: Optional[bool] = True,
@@ -250,10 +225,10 @@ def get_unweighted_text_embeddings(


  def get_weighted_text_embeddings(
-     pipe: DiffusionPipeline,
      prompt: Union[str, List[str]],
      uncond_prompt: Optional[Union[str, List[str]]] = None,
-     max_embeddings_multiples: Optional[int] = 1,
      no_boseos_middle: Optional[bool] = False,
      skip_parsing: Optional[bool] = False,
      skip_weighting: Optional[bool] = False,
@@ -267,14 +242,14 @@
      Also, to regularize of the embedding, the weighted embedding would be scaled to preserve the original mean.

      Args:
-         pipe (`DiffusionPipeline`):
              Pipe to provide access to the tokenizer and the text encoder.
          prompt (`str` or `List[str]`):
              The prompt or prompts to guide the image generation.
          uncond_prompt (`str` or `List[str]`):
              The unconditional prompt or prompts for guide the image generation. If unconditional prompt
              is provided, the embeddings of prompt and uncond_prompt are concatenated.
-         max_embeddings_multiples (`int`, *optional*, defaults to `1`):
              The max multiple length of prompt embeddings compared to the max output length of text encoder.
          no_boseos_middle (`bool`, *optional*, defaults to `False`):
              If the length of text token is multiples of the capacity of text encoder, whether reserve the starting and
@@ -390,11 +365,11 @@ def preprocess_image(image):
      return 2.0 * image - 1.0


- def preprocess_mask(mask):
      mask = mask.convert("L")
      w, h = mask.size
      w, h = map(lambda x: x - x % 32, (w, h))  # resize to integer multiple of 32
-     mask = mask.resize((w // 8, h // 8), resample=PIL_INTERPOLATION["nearest"])
      mask = np.array(mask).astype(np.float32) / 255.0
      mask = np.tile(mask, (4, 1, 1))
      mask = mask[None].transpose(0, 1, 2, 3)  # what does this step do?
@@ -403,7 +378,7 @@ def preprocess_mask(mask):
      return mask


- class StableDiffusionLongPromptWeightingPipeline(DiffusionPipeline):
      r"""
      Pipeline for text-to-image generation using Stable Diffusion without tokens length limit, and support parsing
      weighting in prompt.
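Since get_weighted_text_embeddings is what turns a long, weighted prompt into embeddings, a hedged sketch of calling it directly may help; it assumes `pipe` is an already loaded instance of this pipeline and uses the new default of max_embeddings_multiples=3 introduced in this commit:

# Sketch only: `pipe` is assumed to be a loaded StableDiffusionLongPromptWeightingPipeline.
long_prompt = "a (detailed:1.2) portrait of an old fisherman, " + "intricate textures, " * 15
text_embeddings, uncond_embeddings = get_weighted_text_embeddings(
    pipe=pipe,
    prompt=long_prompt,
    uncond_prompt="lowres, (blurry:1.3)",
    max_embeddings_multiples=3,  # allow prompts up to ~3x CLIP's 77-token window
)
# The two tensors can be concatenated and fed to the UNet as encoder_hidden_states,
# which is exactly what the new _encode_prompt helper below does.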
@@ -438,51 +413,12 @@ class StableDiffusionLongPromptWeightingPipeline(DiffusionPipeline):
          text_encoder: CLIPTextModel,
          tokenizer: CLIPTokenizer,
          unet: UNet2DConditionModel,
-         scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler, DPMSolverMultistepScheduler, EulerDiscreteScheduler, EulerAncestralDiscreteScheduler],
          safety_checker: StableDiffusionSafetyChecker,
          feature_extractor: CLIPFeatureExtractor,
      ):
-         super().__init__()
-
-         if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
-             deprecation_message = (
-                 f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
-                 f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
-                 "to update the config accordingly as leaving `steps_offset` might led to incorrect results"
-                 " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
-                 " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
-                 " file"
-             )
-             deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
-             new_config = dict(scheduler.config)
-             new_config["steps_offset"] = 1
-             scheduler._internal_dict = FrozenDict(new_config)
-
-         if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
-             deprecation_message = (
-                 f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
-                 " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
-                 " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
-                 " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
-                 " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
-             )
-             deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
-             new_config = dict(scheduler.config)
-             new_config["clip_sample"] = False
-             scheduler._internal_dict = FrozenDict(new_config)
-
-         #if safety_checker is None:
-         if False:
-             logger.warn(
-                 f"You have disabled the safety checker for {self.__class__} by passing `safety_checker=None`. Ensure"
-                 " that you abide to the conditions of the Stable Diffusion license and do not expose unfiltered"
-                 " results in services or applications open to the public. Both the diffusers team and Hugging Face"
-                 " strongly recommend to keep the safety filter enabled in all public facing circumstances, disabling"
-                 " it only for use-cases that involve analyzing network behavior or auditing its results. For more"
-                 " information, please have a look at https://github.com/huggingface/diffusers/pull/254 ."
-             )
-
-         self.register_modules(
              vae=vae,
              text_encoder=text_encoder,
              tokenizer=tokenizer,
@@ -490,89 +426,171 @@ class StableDiffusionLongPromptWeightingPipeline(DiffusionPipeline):
          scheduler=scheduler,
          safety_checker=safety_checker,
          feature_extractor=feature_extractor,
      )

- def enable_xformers_memory_efficient_attention(self):
      r"""
-     Enable memory efficient attention as implemented in xformers.
-
-     When this option is enabled, you should observe lower GPU memory usage and a potential speed up at inference
-     time. Speed up at training time is not guaranteed.

-     Warning: When Memory Efficient Attention and Sliced attention are both enabled, the Memory Efficient Attention
-     is used.
      """
-     self.unet.set_use_memory_efficient_attention_xformers(True)

- def disable_xformers_memory_efficient_attention(self):
-     r"""
-     Disable memory efficient attention as implemented in xformers.
-     """
-     self.unet.set_use_memory_efficient_attention_xformers(False)

- def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
-     r"""
-     Enable sliced attention computation.

-     When this option is enabled, the attention module will split the input tensor in slices, to compute attention
-     in several steps. This is useful to save some memory in exchange for a small speed decrease.

-     Args:
-         slice_size (`str` or `int`, *optional*, defaults to `"auto"`):
-             When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
-             a number is provided, uses as many slices as `attention_head_dim // slice_size`. In this case,
-             `attention_head_dim` must be a multiple of `slice_size`.
-     """
-     if slice_size == "auto":
-         if isinstance(self.unet.config.attention_head_dim, int):
-             # half the attention head size is usually a good trade-off between
-             # speed and memory
-             slice_size = self.unet.config.attention_head_dim // 2
-         else:
-             # if `attention_head_dim` is a list, take the smallest head size
-             slice_size = min(self.unet.config.attention_head_dim)
-
-     self.unet.set_attention_slice(slice_size)

- def disable_attention_slicing(self):
-     r"""
-     Disable sliced attention computation. If `enable_attention_slicing` was previously invoked, this method will go
-     back to computing attention in one step.
-     """
-     # set slice_size = `None` to disable `attention slicing`
-     self.enable_attention_slicing(None)

- def enable_vae_slicing(self):
-     r"""
-     Enable sliced VAE decoding.
-     When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several
-     steps. This is useful to save some memory and allow larger batch sizes.
-     """
-     self.vae.enable_slicing()

- def disable_vae_slicing(self):
-     r"""
-     Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to
-     computing decoding in one step.
-     """
-     self.vae.disable_slicing()

- def enable_sequential_cpu_offload(self):
-     r"""
-     Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
-     text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a
-     `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called.
-     """
-     if is_accelerate_available():
-         from accelerate import cpu_offload
      else:
-         raise ImportError("Please install accelerate via `pip install accelerate`")

-     device = self.device

-     for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae, self.safety_checker]:
-         if cpu_offloaded_model is not None:
-             cpu_offload(cpu_offloaded_model, device)

      @torch.no_grad()
      def __call__(
@@ -676,221 +694,111 @@ class StableDiffusionLongPromptWeightingPipeline(DiffusionPipeline):
676
  init_image = deprecate("init_image", "0.12.0", message, take_from=kwargs)
677
  image = init_image or image
678
 
679
- if isinstance(prompt, str):
680
- batch_size = 1
681
- prompt = [prompt]
682
- elif isinstance(prompt, list):
683
- batch_size = len(prompt)
684
- else:
685
- raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
686
-
687
- if strength < 0 or strength > 1:
688
- raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
689
 
690
- if height % 8 != 0 or width % 8 != 0:
691
- raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
692
-
693
- if (callback_steps is None) or (
694
- callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
695
- ):
696
- raise ValueError(
697
- f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
698
- f" {type(callback_steps)}."
699
- )
700
-
701
- # get prompt text embeddings
702
 
 
 
 
703
  # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
704
  # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
705
  # corresponds to doing no classifier free guidance.
706
  do_classifier_free_guidance = guidance_scale > 1.0
707
- # get unconditional embeddings for classifier free guidance
708
- if negative_prompt is None:
709
- negative_prompt = [""] * batch_size
710
- elif isinstance(negative_prompt, str):
711
- negative_prompt = [negative_prompt] * batch_size
712
- if batch_size != len(negative_prompt):
713
- raise ValueError(
714
- f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
715
- f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
716
- " the batch size of `prompt`."
717
- )
718
 
719
- text_embeddings, uncond_embeddings = get_weighted_text_embeddings(
720
- pipe=self,
721
- prompt=prompt,
722
- uncond_prompt=negative_prompt if do_classifier_free_guidance else None,
723
- max_embeddings_multiples=max_embeddings_multiples,
724
- **kwargs,
 
 
725
  )
726
- bs_embed, seq_len, _ = text_embeddings.shape
727
- text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)
728
- text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
729
-
730
- if do_classifier_free_guidance:
731
- bs_embed, seq_len, _ = uncond_embeddings.shape
732
- uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1)
733
- uncond_embeddings = uncond_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
734
- text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
735
-
736
- # set timesteps
737
- self.scheduler.set_timesteps(num_inference_steps)
738
-
739
- latents_dtype = text_embeddings.dtype
740
- init_latents_orig = None
741
- mask = None
742
- noise = None
743
-
744
- if image is None:
745
- # get the initial random noise unless the user supplied it
746
-
747
- # Unlike in other pipelines, latents need to be generated in the target device
748
- # for 1-to-1 results reproducibility with the CompVis implementation.
749
- # However this currently doesn't work in `mps`.
750
- latents_shape = (
751
- batch_size * num_images_per_prompt,
752
- self.unet.in_channels,
753
- height // 8,
754
- width // 8,
755
- )
756
-
757
- if latents is None:
758
- if self.device.type == "mps":
759
- # randn does not exist on mps
760
- latents = torch.randn(
761
- latents_shape,
762
- generator=generator,
763
- device="cpu",
764
- dtype=latents_dtype,
765
- ).to(self.device)
766
- else:
767
- latents = torch.randn(
768
- latents_shape,
769
- generator=generator,
770
- device=self.device,
771
- dtype=latents_dtype,
772
- )
773
- else:
774
- if latents.shape != latents_shape:
775
- raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {latents_shape}")
776
- latents = latents.to(self.device)
777
-
778
- timesteps = self.scheduler.timesteps.to(self.device)
779
-
780
- # scale the initial noise by the standard deviation required by the scheduler
781
- latents = latents * self.scheduler.init_noise_sigma
782
  else:
783
- if isinstance(image, PIL.Image.Image):
784
- image = preprocess_image(image)
785
- # encode the init image into latents and scale the latents
786
- image = image.to(device=self.device, dtype=latents_dtype)
787
- init_latent_dist = self.vae.encode(image).latent_dist
788
- init_latents = init_latent_dist.sample(generator=generator)
789
- init_latents = 0.18215 * init_latents
790
- init_latents = torch.cat([init_latents] * batch_size * num_images_per_prompt, dim=0)
791
- init_latents_orig = init_latents
792
-
793
- # preprocess mask
794
- if mask_image is not None:
795
- if isinstance(mask_image, PIL.Image.Image):
796
- mask_image = preprocess_mask(mask_image)
797
- mask_image = mask_image.to(device=self.device, dtype=latents_dtype)
798
- mask = torch.cat([mask_image] * batch_size * num_images_per_prompt)
799
-
800
- # check sizes
801
- if not mask.shape == init_latents.shape:
802
- raise ValueError("The mask and image should be the same size!")
803
-
804
- # get the original timestep using init_timestep
805
- offset = self.scheduler.config.get("steps_offset", 0)
806
- init_timestep = int(num_inference_steps * strength) + offset
807
- init_timestep = min(init_timestep, num_inference_steps)
808
-
809
- timesteps = self.scheduler.timesteps[-init_timestep]
810
- timesteps = torch.tensor([timesteps] * batch_size * num_images_per_prompt, device=self.device)
811
-
812
- # add noise to latents using the timesteps
813
- if self.device.type == "mps":
814
- # randn does not exist on mps
815
- noise = torch.randn(
816
- init_latents.shape,
817
- generator=generator,
818
- device="cpu",
819
- dtype=latents_dtype,
820
- ).to(self.device)
821
- else:
822
- noise = torch.randn(
823
- init_latents.shape,
824
- generator=generator,
825
- device=self.device,
826
- dtype=latents_dtype,
827
- )
828
- latents = self.scheduler.add_noise(init_latents, noise, timesteps)
829
-
830
- t_start = max(num_inference_steps - init_timestep + offset, 0)
831
- timesteps = self.scheduler.timesteps[t_start:].to(self.device)
832
-
833
- # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
834
- # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
835
- # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
836
- # and should be between [0, 1]
837
- accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
838
- extra_step_kwargs = {}
839
- if accepts_eta:
840
- extra_step_kwargs["eta"] = eta
841
 
842
- for i, t in enumerate(self.progress_bar(timesteps)):
843
- # expand the latents if we are doing classifier free guidance
844
- latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
845
- latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
846
 
847
- # predict the noise residual
848
- noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
 
 
 
 
 
849
 
850
- # perform guidance
851
- if do_classifier_free_guidance:
852
- noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
853
- noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
854
 
855
- # compute the previous noisy sample x_t -> x_t-1
856
- latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
 
 
857
 
858
- if mask is not None:
859
- # masking
860
- init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, torch.tensor([t]))
861
- latents = (init_latents_proper * mask) + (latents * (1 - mask))
862
 
863
- # call the callback, if provided
864
- if i % callback_steps == 0:
865
- if callback is not None:
866
- callback(i, t, latents)
867
- if is_cancelled_callback is not None and is_cancelled_callback():
868
- return None
869
 
870
- latents = 1 / 0.18215 * latents
871
- image = self.vae.decode(latents).sample
 
 
 
 
 
 
872
 
873
- image = (image / 2 + 0.5).clamp(0, 1)
 
874
 
875
- # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16
876
- image = image.cpu().permute(0, 2, 3, 1).float().numpy()
877
-
878
- if self.safety_checker is not None:
879
- safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(
880
- self.device
881
- )
882
- image, has_nsfw_concept = self.safety_checker(
883
- images=image,
884
- clip_input=safety_checker_input.pixel_values.to(text_embeddings.dtype),
885
- )
886
- else:
887
- has_nsfw_concept = None
888
 
 
889
  if output_type == "pil":
890
  image = self.numpy_to_pil(image)
891
 
892
  if not return_dict:
893
- return (image, has_nsfw_concept)
894
 
895
  return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)
896
 
@@ -910,6 +818,7 @@ class StableDiffusionLongPromptWeightingPipeline(DiffusionPipeline):
910
  output_type: Optional[str] = "pil",
911
  return_dict: bool = True,
912
  callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
 
913
  callback_steps: Optional[int] = 1,
914
  **kwargs,
915
  ):
@@ -957,6 +866,9 @@ class StableDiffusionLongPromptWeightingPipeline(DiffusionPipeline):
957
  callback (`Callable`, *optional*):
958
  A function that will be called every `callback_steps` steps during inference. The function will be
959
  called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
 
 
 
960
  callback_steps (`int`, *optional*, defaults to 1):
961
  The frequency at which the `callback` function will be called. If not specified, the callback will be
962
  called at every step.
@@ -982,6 +894,7 @@ class StableDiffusionLongPromptWeightingPipeline(DiffusionPipeline):
982
  output_type=output_type,
983
  return_dict=return_dict,
984
  callback=callback,
 
985
  callback_steps=callback_steps,
986
  **kwargs,
987
  )
@@ -1001,6 +914,7 @@ class StableDiffusionLongPromptWeightingPipeline(DiffusionPipeline):
1001
  output_type: Optional[str] = "pil",
1002
  return_dict: bool = True,
1003
  callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
 
1004
  callback_steps: Optional[int] = 1,
1005
  **kwargs,
1006
  ):
@@ -1049,6 +963,9 @@ class StableDiffusionLongPromptWeightingPipeline(DiffusionPipeline):
1049
  callback (`Callable`, *optional*):
1050
  A function that will be called every `callback_steps` steps during inference. The function will be
1051
  called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
 
 
 
1052
  callback_steps (`int`, *optional*, defaults to 1):
1053
  The frequency at which the `callback` function will be called. If not specified, the callback will be
1054
  called at every step.
@@ -1073,6 +990,7 @@ class StableDiffusionLongPromptWeightingPipeline(DiffusionPipeline):
1073
  output_type=output_type,
1074
  return_dict=return_dict,
1075
  callback=callback,
 
1076
  callback_steps=callback_steps,
1077
  **kwargs,
1078
  )
@@ -1093,6 +1011,7 @@ class StableDiffusionLongPromptWeightingPipeline(DiffusionPipeline):
1093
  output_type: Optional[str] = "pil",
1094
  return_dict: bool = True,
1095
  callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
 
1096
  callback_steps: Optional[int] = 1,
1097
  **kwargs,
1098
  ):
@@ -1145,6 +1064,9 @@ class StableDiffusionLongPromptWeightingPipeline(DiffusionPipeline):
1145
  callback (`Callable`, *optional*):
1146
  A function that will be called every `callback_steps` steps during inference. The function will be
1147
  called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
 
 
 
1148
  callback_steps (`int`, *optional*, defaults to 1):
1149
  The frequency at which the `callback` function will be called. If not specified, the callback will be
1150
  called at every step.
@@ -1170,6 +1092,7 @@ class StableDiffusionLongPromptWeightingPipeline(DiffusionPipeline):
1170
  output_type=output_type,
1171
  return_dict=return_dict,
1172
  callback=callback,
 
1173
  callback_steps=callback_steps,
1174
  **kwargs,
1175
  )
@@ -1184,202 +1107,3 @@ class StableDiffusionLongPromptWeightingPipeline(DiffusionPipeline):
1184
  return_tensors="pt",
1185
  )
1186
  text_embeddings = self.text_encoder(text_input.input_ids.to(self.device))[0]
1187
-
1188
- # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
1189
- # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
1190
- # corresponds to doing no classifier free guidance.
1191
- do_classifier_free_guidance = guidance_scale > 1.0
1192
- # get unconditional embeddings for classifier free guidance
1193
- if do_classifier_free_guidance:
1194
- max_length = text_input.input_ids.shape[-1]
1195
- uncond_input = self.tokenizer(
1196
- [""], padding="max_length", max_length=max_length, return_tensors="pt"
1197
- )
1198
- uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
1199
-
1200
- # For classifier free guidance, we need to do two forward passes.
1201
- # Here we concatenate the unconditional and text embeddings into a single batch
1202
- # to avoid doing two forward passes
1203
- text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
1204
-
1205
- return text_embeddings
1206
-
1207
- def slerp(self, t, v0, v1, DOT_THRESHOLD=0.9995):
1208
- """ helper function to spherically interpolate two arrays v1 v2
1209
- from https://gist.github.com/karpathy/00103b0037c5aaea32fe1da1af553355
1210
- this should be better than lerping for moving between noise spaces """
1211
-
1212
- if not isinstance(v0, np.ndarray):
1213
- inputs_are_torch = True
1214
- input_device = v0.device
1215
- v0 = v0.cpu().numpy()
1216
- v1 = v1.cpu().numpy()
1217
-
1218
- dot = np.sum(v0 * v1 / (np.linalg.norm(v0) * np.linalg.norm(v1)))
1219
- if np.abs(dot) > DOT_THRESHOLD:
1220
- v2 = (1 - t) * v0 + t * v1
1221
- else:
1222
- theta_0 = np.arccos(dot)
1223
- sin_theta_0 = np.sin(theta_0)
1224
- theta_t = theta_0 * t
1225
- sin_theta_t = np.sin(theta_t)
1226
- s0 = np.sin(theta_0 - theta_t) / sin_theta_0
1227
- s1 = sin_theta_t / sin_theta_0
1228
- v2 = s0 * v0 + s1 * v1
1229
-
1230
- if inputs_are_torch:
1231
- v2 = torch.from_numpy(v2).to(input_device)
1232
-
1233
- return v2
1234
-
1235
- def lerp_between_prompts(self, first_prompt, second_prompt, seed = None, length = 10, save=False, guidance_scale: Optional[float] = 7.5, **kwargs):
1236
- first_embedding = self.get_text_latent_space(first_prompt)
1237
- second_embedding = self.get_text_latent_space(second_prompt)
1238
- if not seed:
1239
- seed = random.randint(0, sys.maxsize)
1240
- generator = torch.Generator(self.device)
1241
- generator.manual_seed(seed)
1242
- generator_state = generator.get_state()
1243
- lerp_embed_points = []
1244
- for i in range(length):
1245
- weight = i / length
1246
- tensor_lerp = torch.lerp(first_embedding, second_embedding, weight)
1247
- lerp_embed_points.append(tensor_lerp)
1248
- images = []
1249
- for idx, latent_point in enumerate(lerp_embed_points):
1250
- generator.set_state(generator_state)
1251
- image = self.diffuse_from_inits(latent_point, **kwargs)["image"][0]
1252
- images.append(image)
1253
- if save:
1254
- image.save(f"{first_prompt}-{second_prompt}-{idx:02d}.png", "PNG")
1255
- return {"images": images, "latent_points": lerp_embed_points,"generator_state": generator_state}
1256
-
1257
- def slerp_through_seeds(self,
1258
- prompt,
1259
- height: Optional[int] = 512,
1260
- width: Optional[int] = 512,
1261
- save = False,
1262
- seed = None, steps = 10, **kwargs):
1263
-
1264
- if not seed:
1265
- seed = random.randint(0, sys.maxsize)
1266
- generator = torch.Generator(self.device)
1267
- generator.manual_seed(seed)
1268
- init_start = torch.randn(
1269
- (1, self.unet.in_channels, height // 8, width // 8),
1270
- generator = generator, device = self.device)
1271
- init_end = torch.randn(
1272
- (1, self.unet.in_channels, height // 8, width // 8),
1273
- generator = generator, device = self.device)
1274
- generator_state = generator.get_state()
1275
- slerp_embed_points = []
1276
- # weight from 0 to 1/(steps - 1), add init_end specifically so that we
1277
- # have len(images) = steps
1278
- for i in range(steps - 1):
1279
- weight = i / steps
1280
- tensor_slerp = self.slerp(weight, init_start, init_end)
1281
- slerp_embed_points.append(tensor_slerp)
1282
- slerp_embed_points.append(init_end)
1283
- images = []
1284
- embed_point = self.get_text_latent_space(prompt)
1285
- for idx, noise_point in enumerate(slerp_embed_points):
1286
- generator.set_state(generator_state)
1287
- image = self.diffuse_from_inits(embed_point, init = noise_point, **kwargs)["image"][0]
1288
- images.append(image)
1289
- if save:
1290
- image.save(f"{seed}-{idx:02d}.png", "PNG")
1291
- return {"images": images, "noise_samples": slerp_embed_points,"generator_state": generator_state}
1292
-
1293
- @torch.no_grad()
1294
- def diffuse_from_inits(self, text_embeddings,
1295
- init = None,
1296
- height: Optional[int] = 512,
1297
- width: Optional[int] = 512,
1298
- num_inference_steps: Optional[int] = 50,
1299
- guidance_scale: Optional[float] = 7.5,
1300
- eta: Optional[float] = 0.0,
1301
- generator: Optional[torch.Generator] = None,
1302
- output_type: Optional[str] = "pil",
1303
- **kwargs,):
1304
-
1305
- batch_size = 1
1306
-
1307
- if generator == None:
1308
- generator = torch.Generator("cuda")
1309
- generator_state = generator.get_state()
1310
- # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
1311
- # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
1312
- # corresponds to doing no classifier free guidance.
1313
- do_classifier_free_guidance = guidance_scale > 1.0
1314
- # get the intial random noise
1315
- latents = init if init is not None else torch.randn(
1316
- (batch_size, self.unet.in_channels, height // 8, width // 8),
1317
- generator=generator,
1318
- device=self.device,)
1319
-
1320
- # set timesteps
1321
- accepts_offset = "offset" in set(inspect.signature(self.scheduler.set_timesteps).parameters.keys())
1322
- extra_set_kwargs = {}
1323
- if accepts_offset:
1324
- extra_set_kwargs["offset"] = 1
1325
-
1326
- self.scheduler.set_timesteps(num_inference_steps, **extra_set_kwargs)
1327
-
1328
- # if we use LMSDiscreteScheduler, let's make sure latents are mulitplied by sigmas
1329
- if isinstance(self.scheduler, LMSDiscreteScheduler):
1330
- latents = latents * self.scheduler.sigmas[0]
1331
-
1332
- # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
1333
- # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
1334
- # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
1335
- # and should be between [0, 1]
1336
- accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
1337
- extra_step_kwargs = {}
1338
- if accepts_eta:
1339
- extra_step_kwargs["eta"] = eta
1340
-
1341
- for i, t in tqdm(enumerate(self.scheduler.timesteps)):
1342
- # expand the latents if we are doing classifier free guidance
1343
- latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
1344
- if isinstance(self.scheduler, LMSDiscreteScheduler):
1345
- sigma = self.scheduler.sigmas[i]
1346
- latent_model_input = latent_model_input / ((sigma**2 + 1) ** 0.5)
1347
-
1348
- # predict the noise residual
1349
- noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
1350
-
1351
- # perform guidance
1352
- if do_classifier_free_guidance:
1353
- noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
1354
- noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
1355
-
1356
- # compute the previous noisy sample x_t -> x_t-1
1357
- if isinstance(self.scheduler, LMSDiscreteScheduler):
1358
- latents = self.scheduler.step(noise_pred, i, latents, **extra_step_kwargs).prev_sample
1359
- else:
1360
- latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
1361
-
1362
- # scale and decode the image latents with vae
1363
- latents = 1 / 0.18215 * latents
1364
- image = self.vae.decode(latents)
1365
-
1366
- image = (image / 2 + 0.5).clamp(0, 1)
1367
- image = image.cpu().permute(0, 2, 3, 1).numpy()
1368
-
1369
- if output_type == "pil":
1370
- image = self.numpy_to_pil(image)
1371
-
1372
- return {"image": image, "generator_state": generator_state}
1373
-
1374
- def variation(self, text_embeddings, generator_state, variation_magnitude = 100, **kwargs):
1375
- # random vector to move in latent space
1376
- rand_t = (torch.rand(text_embeddings.shape, device = self.device) * 2) - 1
1377
- rand_mag = torch.sum(torch.abs(rand_t)) / variation_magnitude
1378
- scaled_rand_t = rand_t / rand_mag
1379
- variation_embedding = text_embeddings + scaled_rand_t
1380
-
1381
- generator = torch.Generator("cuda")
1382
- generator.set_state(generator_state)
1383
- result = self.diffuse_from_inits(variation_embedding, generator=generator, **kwargs)
1384
- result.update({"latent_point": variation_embedding})
1385
- return result
 
  from tqdm.auto import tqdm

  import PIL
+ from diffusers import SchedulerMixin, StableDiffusionPipeline
  from diffusers.models import AutoencoderKL, UNet2DConditionModel
+ from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput, StableDiffusionSafetyChecker
+ from diffusers.utils import PIL_INTERPOLATION, deprecate, logging
  from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer


  logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

  re_attention = re.compile(
 
      return res


+ def get_prompts_with_weights(pipe: StableDiffusionPipeline, prompt: List[str], max_length: int):
      r"""
      Tokenize a list of prompts and return its tokens with weights of each token.


  def get_unweighted_text_embeddings(
+     pipe: StableDiffusionPipeline,
      text_input: torch.Tensor,
      chunk_length: int,
      no_boseos_middle: Optional[bool] = True,


  def get_weighted_text_embeddings(
+     pipe: StableDiffusionPipeline,
      prompt: Union[str, List[str]],
      uncond_prompt: Optional[Union[str, List[str]]] = None,
+     max_embeddings_multiples: Optional[int] = 3,
      no_boseos_middle: Optional[bool] = False,
      skip_parsing: Optional[bool] = False,
      skip_weighting: Optional[bool] = False,

      Also, to regularize of the embedding, the weighted embedding would be scaled to preserve the original mean.

      Args:
+         pipe (`StableDiffusionPipeline`):
              Pipe to provide access to the tokenizer and the text encoder.
          prompt (`str` or `List[str]`):
              The prompt or prompts to guide the image generation.
          uncond_prompt (`str` or `List[str]`):
              The unconditional prompt or prompts for guide the image generation. If unconditional prompt
              is provided, the embeddings of prompt and uncond_prompt are concatenated.
+         max_embeddings_multiples (`int`, *optional*, defaults to `3`):
              The max multiple length of prompt embeddings compared to the max output length of text encoder.
          no_boseos_middle (`bool`, *optional*, defaults to `False`):
              If the length of text token is multiples of the capacity of text encoder, whether reserve the starting and

      return 2.0 * image - 1.0


+ def preprocess_mask(mask, scale_factor=8):
      mask = mask.convert("L")
      w, h = mask.size
      w, h = map(lambda x: x - x % 32, (w, h))  # resize to integer multiple of 32
+     mask = mask.resize((w // scale_factor, h // scale_factor), resample=PIL_INTERPOLATION["nearest"])
      mask = np.array(mask).astype(np.float32) / 255.0
      mask = np.tile(mask, (4, 1, 1))
      mask = mask[None].transpose(0, 1, 2, 3)  # what does this step do?

      return mask


+ class StableDiffusionLongPromptWeightingPipeline(StableDiffusionPipeline):
      r"""
      Pipeline for text-to-image generation using Stable Diffusion without tokens length limit, and support parsing
      weighting in prompt.
 
          text_encoder: CLIPTextModel,
          tokenizer: CLIPTokenizer,
          unet: UNet2DConditionModel,
+         scheduler: SchedulerMixin,
          safety_checker: StableDiffusionSafetyChecker,
          feature_extractor: CLIPFeatureExtractor,
+         requires_safety_checker: bool = True,
      ):
+         super().__init__(
              vae=vae,
              text_encoder=text_encoder,
              tokenizer=tokenizer,
              scheduler=scheduler,
              safety_checker=safety_checker,
              feature_extractor=feature_extractor,
+             requires_safety_checker=requires_safety_checker,
          )

+     def _encode_prompt(
+         self,
+         prompt,
+         device,
+         num_images_per_prompt,
+         do_classifier_free_guidance,
+         negative_prompt,
+         max_embeddings_multiples,
+     ):
          r"""
+         Encodes the prompt into text encoder hidden states.

+         Args:
+             prompt (`str` or `list(int)`):
+                 prompt to be encoded
+             device: (`torch.device`):
+                 torch device
+             num_images_per_prompt (`int`):
+                 number of images that should be generated per prompt
+             do_classifier_free_guidance (`bool`):
+                 whether to use classifier free guidance or not
+             negative_prompt (`str` or `List[str]`):
+                 The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
+                 if `guidance_scale` is less than `1`).
+             max_embeddings_multiples (`int`, *optional*, defaults to `3`):
+                 The max multiple length of prompt embeddings compared to the max output length of text encoder.
          """
+         batch_size = len(prompt) if isinstance(prompt, list) else 1

+         if negative_prompt is None:
+             negative_prompt = [""] * batch_size
+         elif isinstance(negative_prompt, str):
+             negative_prompt = [negative_prompt] * batch_size
+         if batch_size != len(negative_prompt):
+             raise ValueError(
+                 f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+                 f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+                 " the batch size of `prompt`."
+             )

+         text_embeddings, uncond_embeddings = get_weighted_text_embeddings(
+             pipe=self,
+             prompt=prompt,
+             uncond_prompt=negative_prompt if do_classifier_free_guidance else None,
+             max_embeddings_multiples=max_embeddings_multiples,
+         )
+         bs_embed, seq_len, _ = text_embeddings.shape
+         text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)
+         text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)

+         if do_classifier_free_guidance:
+             bs_embed, seq_len, _ = uncond_embeddings.shape
+             uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1)
+             uncond_embeddings = uncond_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
+             text_embeddings = torch.cat([uncond_embeddings, text_embeddings])

+         return text_embeddings
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
489
 
490
+ def check_inputs(self, prompt, height, width, strength, callback_steps):
491
+ if not isinstance(prompt, str) and not isinstance(prompt, list):
492
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
 
 
 
 
493
 
494
+ if strength < 0 or strength > 1:
495
+ raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
 
 
 
 
 
496
 
497
+ if height % 8 != 0 or width % 8 != 0:
498
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
 
 
 
 
499
 
500
+ if (callback_steps is None) or (
501
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
502
+ ):
503
+ raise ValueError(
504
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
505
+ f" {type(callback_steps)}."
506
+ )
507
+
508
+ def get_timesteps(self, num_inference_steps, strength, device, is_text2img):
509
+ if is_text2img:
510
+ return self.scheduler.timesteps.to(device), num_inference_steps
511
  else:
512
+ # get the original timestep using init_timestep
513
+ offset = self.scheduler.config.get("steps_offset", 0)
514
+ init_timestep = int(num_inference_steps * strength) + offset
515
+ init_timestep = min(init_timestep, num_inference_steps)
516
+
517
+ t_start = max(num_inference_steps - init_timestep + offset, 0)
518
+ timesteps = self.scheduler.timesteps[t_start:].to(device)
519
+ return timesteps, num_inference_steps - t_start
520
+
521
+ def run_safety_checker(self, image, device, dtype):
522
+ if self.safety_checker is not None:
523
+ safety_checker_input = self.feature_extractor(self.numpy_to_pil(image), return_tensors="pt").to(device)
524
+ image, has_nsfw_concept = self.safety_checker(
525
+ images=image, clip_input=safety_checker_input.pixel_values.to(dtype)
526
+ )
527
+ else:
528
+ has_nsfw_concept = None
529
+ return image, has_nsfw_concept
530
 
531
+ def decode_latents(self, latents):
532
+ latents = 1 / 0.18215 * latents
533
+ image = self.vae.decode(latents).sample
534
+ image = (image / 2 + 0.5).clamp(0, 1)
535
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16
536
+ image = image.cpu().permute(0, 2, 3, 1).float().numpy()
537
+ return image
538
+
539
+ def prepare_extra_step_kwargs(self, generator, eta):
540
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
541
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
542
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
543
+ # and should be between [0, 1]
544
+
545
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
546
+ extra_step_kwargs = {}
547
+ if accepts_eta:
548
+ extra_step_kwargs["eta"] = eta
549
+
550
+ # check if the scheduler accepts generator
551
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
552
+ if accepts_generator:
553
+ extra_step_kwargs["generator"] = generator
554
+ return extra_step_kwargs
555
+
556
+ def prepare_latents(self, image, timestep, batch_size, height, width, dtype, device, generator, latents=None):
557
+ if image is None:
558
+ shape = (
559
+ batch_size,
560
+ self.unet.in_channels,
561
+ height // self.vae_scale_factor,
562
+ width // self.vae_scale_factor,
563
+ )
564
+
565
+ if latents is None:
566
+ if device.type == "mps":
567
+ # randn does not work reproducibly on mps
568
+ latents = torch.randn(shape, generator=generator, device="cpu", dtype=dtype).to(device)
569
+ else:
570
+ latents = torch.randn(shape, generator=generator, device=device, dtype=dtype)
571
+ else:
572
+ if latents.shape != shape:
573
+ raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
574
+ latents = latents.to(device)
575
+
576
+ # scale the initial noise by the standard deviation required by the scheduler
577
+ latents = latents * self.scheduler.init_noise_sigma
578
+ return latents, None, None
579
+ else:
580
+ init_latent_dist = self.vae.encode(image).latent_dist
581
+ init_latents = init_latent_dist.sample(generator=generator)
582
+ init_latents = 0.18215 * init_latents
583
+ init_latents = torch.cat([init_latents] * batch_size, dim=0)
584
+ init_latents_orig = init_latents
585
+ shape = init_latents.shape
586
 
587
+ # add noise to latents using the timesteps
588
+ if device.type == "mps":
589
+ noise = torch.randn(shape, generator=generator, device="cpu", dtype=dtype).to(device)
590
+ else:
591
+ noise = torch.randn(shape, generator=generator, device=device, dtype=dtype)
592
+ latents = self.scheduler.add_noise(init_latents, noise, timestep)
593
+ return latents, init_latents_orig, noise
594
 
595
  @torch.no_grad()
596
  def __call__(
 
694
  init_image = deprecate("init_image", "0.12.0", message, take_from=kwargs)
695
  image = init_image or image
696
 
697
+ # 0. Default height and width to unet
698
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
699
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
 
 
 
 
 
 
 
700
 
701
+ # 1. Check inputs. Raise error if not correct
702
+ self.check_inputs(prompt, height, width, strength, callback_steps)
 
 
 
 
 
 
 
 
 
 
703
 
704
+ # 2. Define call parameters
705
+ batch_size = 1 if isinstance(prompt, str) else len(prompt)
706
+ device = self._execution_device
707
  # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
708
  # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
709
  # corresponds to doing no classifier free guidance.
710
  do_classifier_free_guidance = guidance_scale > 1.0
 
 
 
 
 
 
 
 
 
 
 
711
 
712
+ # 3. Encode input prompt
713
+ text_embeddings = self._encode_prompt(
714
+ prompt,
715
+ device,
716
+ num_images_per_prompt,
717
+ do_classifier_free_guidance,
718
+ negative_prompt,
719
+ max_embeddings_multiples,
720
  )
721
+ dtype = text_embeddings.dtype
722
+
723
+ # 4. Preprocess image and mask
724
+ if isinstance(image, PIL.Image.Image):
725
+ image = preprocess_image(image)
726
+ if image is not None:
727
+ image = image.to(device=self.device, dtype=dtype)
728
+ if isinstance(mask_image, PIL.Image.Image):
729
+ mask_image = preprocess_mask(mask_image, self.vae_scale_factor)
730
+ if mask_image is not None:
731
+ mask = mask_image.to(device=self.device, dtype=dtype)
732
+ mask = torch.cat([mask] * batch_size * num_images_per_prompt)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
733
  else:
734
+ mask = None
735
+
736
+ # 5. set timesteps
737
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
738
+ timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device, image is None)
739
+ latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
740
+
741
+ # 6. Prepare latent variables
742
+ latents, init_latents_orig, noise = self.prepare_latents(
743
+ image,
744
+ latent_timestep,
745
+ batch_size * num_images_per_prompt,
746
+ height,
747
+ width,
748
+ dtype,
749
+ device,
750
+ generator,
751
+ latents,
752
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
753
 
754
+ # 7. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
755
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
 
 
756
 
757
+ # 8. Denoising loop
758
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
759
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
760
+ for i, t in enumerate(timesteps):
761
+ # expand the latents if we are doing classifier free guidance
762
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
763
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
764
 
765
+ # predict the noise residual
766
+ noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
 
 
767
 
768
+ # perform guidance
769
+ if do_classifier_free_guidance:
770
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
771
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
772
 
773
+ # compute the previous noisy sample x_t -> x_t-1
774
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
 
 
775
 
776
+ if mask is not None:
777
+ # masking
778
+ init_latents_proper = self.scheduler.add_noise(init_latents_orig, noise, torch.tensor([t]))
779
+ latents = (init_latents_proper * mask) + (latents * (1 - mask))
 
 
780
 
781
+ # call the callback, if provided
782
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
783
+ progress_bar.update()
784
+ if i % callback_steps == 0:
785
+ if callback is not None:
786
+ callback(i, t, latents)
787
+ if is_cancelled_callback is not None and is_cancelled_callback():
788
+ return None
789
 
790
+ # 9. Post-processing
791
+ image = self.decode_latents(latents)
792
 
793
+ # 10. Run safety checker
794
+ image, has_nsfw_concept = self.run_safety_checker(image, device, text_embeddings.dtype)
 
 
 
 
 
 
 
 
 
 
 
795
 
796
+ # 11. Convert to PIL
797
  if output_type == "pil":
798
  image = self.numpy_to_pil(image)
799
 
800
  if not return_dict:
801
+ return image, has_nsfw_concept
802
 
803
  return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept)

          output_type: Optional[str] = "pil",
          return_dict: bool = True,
          callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+         is_cancelled_callback: Optional[Callable[[], bool]] = None,
          callback_steps: Optional[int] = 1,
          **kwargs,
      ):

          callback (`Callable`, *optional*):
              A function that will be called every `callback_steps` steps during inference. The function will be
              called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+         is_cancelled_callback (`Callable`, *optional*):
+             A function that will be called every `callback_steps` steps during inference. If the function returns
+             `True`, the inference will be cancelled.
          callback_steps (`int`, *optional*, defaults to 1):
              The frequency at which the `callback` function will be called. If not specified, the callback will be
              called at every step.

              output_type=output_type,
              return_dict=return_dict,
              callback=callback,
+             is_cancelled_callback=is_cancelled_callback,
              callback_steps=callback_steps,
              **kwargs,
          )

          output_type: Optional[str] = "pil",
          return_dict: bool = True,
          callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+         is_cancelled_callback: Optional[Callable[[], bool]] = None,
          callback_steps: Optional[int] = 1,
          **kwargs,
      ):

          callback (`Callable`, *optional*):
              A function that will be called every `callback_steps` steps during inference. The function will be
              called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+         is_cancelled_callback (`Callable`, *optional*):
+             A function that will be called every `callback_steps` steps during inference. If the function returns
+             `True`, the inference will be cancelled.
          callback_steps (`int`, *optional*, defaults to 1):
              The frequency at which the `callback` function will be called. If not specified, the callback will be
              called at every step.

              output_type=output_type,
              return_dict=return_dict,
              callback=callback,
+             is_cancelled_callback=is_cancelled_callback,
              callback_steps=callback_steps,
              **kwargs,
          )

          output_type: Optional[str] = "pil",
          return_dict: bool = True,
          callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+         is_cancelled_callback: Optional[Callable[[], bool]] = None,
          callback_steps: Optional[int] = 1,
          **kwargs,
      ):

          callback (`Callable`, *optional*):
              A function that will be called every `callback_steps` steps during inference. The function will be
              called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+         is_cancelled_callback (`Callable`, *optional*):
+             A function that will be called every `callback_steps` steps during inference. If the function returns
+             `True`, the inference will be cancelled.
          callback_steps (`int`, *optional*, defaults to 1):
              The frequency at which the `callback` function will be called. If not specified, the callback will be
              called at every step.

              output_type=output_type,
              return_dict=return_dict,
              callback=callback,
+             is_cancelled_callback=is_cancelled_callback,
              callback_steps=callback_steps,
              **kwargs,
          )

              return_tensors="pt",
          )
          text_embeddings = self.text_encoder(text_input.input_ids.to(self.device))[0]
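The `is_cancelled_callback` argument threaded through the wrapper methods above lets a caller abort the denoising loop: when it returns True, __call__ returns None instead of a StableDiffusionPipelineOutput. A hedged usage sketch, again assuming `pipe` is a loaded instance of this pipeline:

cancelled = {"flag": False}  # e.g. toggled from a UI or another thread

def should_cancel() -> bool:
    return cancelled["flag"]

result = pipe(
    "a castle on a hill at (sunset:1.1), masterpiece",
    num_inference_steps=50,
    is_cancelled_callback=should_cancel,
)
if result is None:
    print("generation was cancelled")
else:
    result.images[0].save("castle.png")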