Update pipeline.py

pipeline.py CHANGED (+13, -19)
@@ -1176,11 +1176,16 @@ class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdap
             lora_scale=text_encoder_lora_scale,
             clip_skip=clip_skip,
         )
+        # print prompt_embeds shape
+        print("prompt_embeds shape after encoding")
+        print(prompt_embeds.shape)
+
         # For classifier free guidance, we need to do two forward passes.
         # Here we concatenate the unconditional and text embeddings into a single batch
         # to avoid doing two forward passes
         if do_classifier_free_guidance:
-            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
+            # stack negative and positive prompt embeddings on a new dimension after the first (batch) dimension
+            prompt_embeds = torch.stack([negative_prompt_embeds, prompt_embeds], dim=1)

         if ip_adapter_image is not None:
             output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True
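The stack replaces what was presumably the stock batch-dimension `torch.cat` (the removed line's content did not survive extraction and is reconstructed above from upstream diffusers). A minimal sketch of the resulting shapes, with all sizes (2 prompts, CLIP-style 77x768 embeddings) assumed for illustration:

    import torch

    # hypothetical sizes: N prompts, sequence length L, hidden size D
    N, L, D = 2, 77, 768
    negative_prompt_embeds = torch.randn(N, L, D)
    prompt_embeds = torch.randn(N, L, D)

    # dim=1 keeps each prompt's unconditional/conditional pair together
    stacked = torch.stack([negative_prompt_embeds, prompt_embeds], dim=1)
    print(stacked.shape)     # torch.Size([2, 2, 77, 768])

    # indexing one prompt later recovers the batch-of-2 pair the UNet
    # receives under classifier-free guidance
    print(stacked[0].shape)  # torch.Size([2, 77, 768])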
@@ -1402,6 +1407,8 @@ class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdap

         # Denoising loop
         num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+        # get the number of prompts from the first dimension of prompt_embeds
+        num_prompts = prompt_embeds.shape[0]
         with self.progress_bar(total=len(timesteps)) as progress_bar:
             for i, t in enumerate(timesteps):
                 noise_pred_uncond_sum = torch.zeros_like(latents).to(device).to(dtype=torch.float16)
@@ -1421,9 +1428,6 @@ class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdap
                 latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

                 if self.controlnet != None and i < int(control_end*num_inference_steps):
-
-                    torch.cuda.synchronize() # Synchronize GPU
-                    control_start = time.time()

                     current_context_conditioning_frames = conditioning_frames[current_context_indexes, :, :, :]
                     current_context_conditioning_frames = torch.cat([current_context_conditioning_frames] * 2) if do_classifier_free_guidance else current_context_conditioning_frames
@@ -1454,6 +1458,10 @@ class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdap
                     )


+                    # get the current prompt index based on the current context position (for blending between multiple prompts)
+                    context_position = current_context_indexes[0] % context_size
+                    current_prompt_index = int(context_position / (context_size / num_prompts))
+
                     down_block_res_samples, mid_block_res_sample = self.controlnet(
                         control_model_input,
                         t,
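The two new lines map a context window's first frame index to a prompt index, so successive windows cycle through the available prompts. A toy walk-through with assumed values (a context_size of 16 and two prompts; the real values come from the pipeline's call arguments):

    context_size = 16
    num_prompts = 2

    for first_frame_index in (0, 4, 8, 12):
        context_position = first_frame_index % context_size
        current_prompt_index = int(context_position / (context_size / num_prompts))
        print(first_frame_index, "->", current_prompt_index)

    # prints: 0 -> 0, 4 -> 0, 8 -> 1, 12 -> 1
    # windows starting in the first half of the cycle use prompt 0,
    # those starting in the second half use prompt 1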
@@ -1464,12 +1472,11 @@ class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdap
                         return_dict=False,
                     )

-                    unet_start = time.time()
                     # predict the noise residual with the added controlnet residuals
                     noise_pred = self.unet(
                         latent_model_input,
                         t,
-                        encoder_hidden_states=prompt_embeds,
+                        encoder_hidden_states=prompt_embeds[current_prompt_index],
                         cross_attention_kwargs=cross_attention_kwargs,
                         added_cond_kwargs=added_cond_kwargs,
                         down_block_additional_residuals=down_block_res_samples,
@@ -1478,8 +1485,6 @@ class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdap

                 else:
                     # predict the noise residual without controlnet
-                    torch.cuda.synchronize()
-                    unet_start = time.time()
                     noise_pred = self.unet(
                         latent_model_input,
                         t,
@@ -1489,19 +1494,8 @@ class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdap
                     ).sample

                 if do_classifier_free_guidance:
-                    # Start timing for overall guidance process
-                    torch.cuda.synchronize() # Synchronize GPU before starting timing
-                    start_guidance_time = time.time()
-
-                    # Timing for chunk operation
-                    torch.cuda.synchronize() # Synchronize GPU before chunking
-                    time_chunk_start = time.time()

                     noise_pred_uncond, noise_pred_text = torch.chunk(noise_pred, 2, dim=0)
-
-                    # Timing for batch addition and latent counter increment
-                    torch.cuda.synchronize() # Synchronize GPU before batch addition
-                    time_batch_addition_start = time.time()

                     # Perform batch addition
                     noise_pred_uncond_sum[..., current_context_indexes, :, :] += noise_pred_uncond
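With the timing scaffolding removed, what remains is the sliding-window accumulation: each window's unconditional prediction is added into a running per-frame sum. A self-contained sketch of that indexing pattern, with every shape assumed (the real tensor mirrors the pipeline's latents; any averaging counter lives outside this diff):

    import torch

    # assumed latent layout: [batch, channels, frames, height, width]
    B, C, F, H, W = 1, 4, 16, 8, 8
    noise_pred_uncond_sum = torch.zeros(B, C, F, H, W)

    current_context_indexes = [4, 5, 6, 7]  # one window of 4 frames
    noise_pred_uncond = torch.randn(B, C, len(current_context_indexes), H, W)

    # `...` spans the batch and channel dims; the index list selects the frame dim
    noise_pred_uncond_sum[..., current_context_indexes, :, :] += noise_pred_uncond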