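Fix broken num_images_per_prompt: get_weighted_text_embeddings now returns the text and unconditional embeddings separately, and the pipeline repeats each one num_images_per_prompt times before concatenating them for classifier-free guidance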
Browse files- pipeline.py +12 -7
pipeline.py
CHANGED
@@ -324,13 +324,9 @@ def get_weighted_text_embeddings(
|
|
324 |
uncond_embeddings *= uncond_weights.unsqueeze(-1)
|
325 |
uncond_embeddings *= previous_mean / uncond_embeddings.mean(axis=[-2, -1])
|
326 |
|
327 |
-
# For classifier free guidance, we need to do two forward passes.
|
328 |
-
# Here we concatenate the unconditional and text embeddings into a single batch
|
329 |
-
# to avoid doing two forward passes
|
330 |
if uncond_prompt is not None:
|
331 |
-
text_embeddings
|
332 |
-
|
333 |
-
return text_embeddings
|
334 |
|
335 |
|
336 |
def preprocess_image(image):
|
@@ -598,13 +594,22 @@ class StableDiffusionLongPromptWeightingPipeline(DiffusionPipeline):
|
|
598 |
else:
|
599 |
uncond_tokens = negative_prompt
|
600 |
|
601 |
-
text_embeddings = get_weighted_text_embeddings(
|
602 |
pipe=self,
|
603 |
prompt=prompt,
|
604 |
uncond_prompt=uncond_tokens if do_classifier_free_guidance else None,
|
605 |
max_embeddings_multiples=max_embeddings_multiples,
|
606 |
**kwargs
|
607 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
608 |
|
609 |
# set timesteps
|
610 |
self.scheduler.set_timesteps(num_inference_steps)
|
|
|
324 |
uncond_embeddings *= uncond_weights.unsqueeze(-1)
|
325 |
uncond_embeddings *= previous_mean / uncond_embeddings.mean(axis=[-2, -1])
|
326 |
|
|
|
|
|
|
|
327 |
if uncond_prompt is not None:
|
328 |
+
return text_embeddings, uncond_embeddings
|
329 |
+
return text_embeddings, None
|
|
|
330 |
|
331 |
|
332 |
def preprocess_image(image):
|
|
|
594 |
else:
|
595 |
uncond_tokens = negative_prompt
|
596 |
|
597 |
+
text_embeddings, uncond_embeddings = get_weighted_text_embeddings(
|
598 |
pipe=self,
|
599 |
prompt=prompt,
|
600 |
uncond_prompt=uncond_tokens if do_classifier_free_guidance else None,
|
601 |
max_embeddings_multiples=max_embeddings_multiples,
|
602 |
**kwargs
|
603 |
)
|
604 |
+
bs_embed, seq_len, _ = text_embeddings.shape
|
605 |
+
text_embeddings = text_embeddings.repeat(1, num_images_per_prompt, 1)
|
606 |
+
text_embeddings = text_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
|
607 |
+
|
608 |
+
if do_classifier_free_guidance:
|
609 |
+
bs_embed, seq_len, _ = uncond_embeddings.shape
|
610 |
+
uncond_embeddings = uncond_embeddings.repeat(1, num_images_per_prompt, 1)
|
611 |
+
uncond_embeddings = uncond_embeddings.view(bs_embed * num_images_per_prompt, seq_len, -1)
|
612 |
+
text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
|
613 |
|
614 |
# set timesteps
|
615 |
self.scheduler.set_timesteps(num_inference_steps)
|