smoothieAI committed on
Commit
127315d
·
verified ·
1 Parent(s): 14c4be1

Update pipeline.py

Browse files
Files changed (1) hide show
  1. pipeline.py +13 -19
pipeline.py CHANGED
@@ -1176,11 +1176,16 @@ class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdap
1176
  lora_scale=text_encoder_lora_scale,
1177
  clip_skip=clip_skip,
1178
  )
 
 
 
 
1179
  # For classifier free guidance, we need to do two forward passes.
1180
  # Here we concatenate the unconditional and text embeddings into a single batch
1181
  # to avoid doing two forward passes
1182
  if do_classifier_free_guidance:
1183
- prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
 
1184
 
1185
  if ip_adapter_image is not None:
1186
  output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True
@@ -1402,6 +1407,8 @@ class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdap
1402
 
1403
  # Denoising loop
1404
  num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
 
 
1405
  with self.progress_bar(total=len(timesteps)) as progress_bar:
1406
  for i, t in enumerate(timesteps):
1407
  noise_pred_uncond_sum = torch.zeros_like(latents).to(device).to(dtype=torch.float16)
@@ -1421,9 +1428,6 @@ class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdap
1421
  latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
1422
 
1423
  if self.controlnet != None and i < int(control_end*num_inference_steps):
1424
-
1425
- torch.cuda.synchronize() # Synchronize GPU
1426
- control_start = time.time()
1427
 
1428
  current_context_conditioning_frames = conditioning_frames[current_context_indexes, :, :, :]
1429
  current_context_conditioning_frames = torch.cat([current_context_conditioning_frames] * 2) if do_classifier_free_guidance else current_context_conditioning_frames
@@ -1454,6 +1458,10 @@ class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdap
1454
  )
1455
 
1456
 
 
 
 
 
1457
  down_block_res_samples, mid_block_res_sample = self.controlnet(
1458
  control_model_input,
1459
  t,
@@ -1464,12 +1472,11 @@ class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdap
1464
  return_dict=False,
1465
  )
1466
 
1467
- unet_start = time.time()
1468
  # predict the noise residual with the added controlnet residuals
1469
  noise_pred = self.unet(
1470
  latent_model_input,
1471
  t,
1472
- encoder_hidden_states=prompt_embeds,
1473
  cross_attention_kwargs=cross_attention_kwargs,
1474
  added_cond_kwargs=added_cond_kwargs,
1475
  down_block_additional_residuals=down_block_res_samples,
@@ -1478,8 +1485,6 @@ class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdap
1478
 
1479
  else:
1480
  # predict the noise residual without controlnet
1481
- torch.cuda.synchronize()
1482
- unet_start = time.time()
1483
  noise_pred = self.unet(
1484
  latent_model_input,
1485
  t,
@@ -1489,19 +1494,8 @@ class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdap
1489
  ).sample
1490
 
1491
  if do_classifier_free_guidance:
1492
- # Start timing for overall guidance process
1493
- torch.cuda.synchronize() # Synchronize GPU before starting timing
1494
- start_guidance_time = time.time()
1495
-
1496
- # Timing for chunk operation
1497
- torch.cuda.synchronize() # Synchronize GPU before chunking
1498
- time_chunk_start = time.time()
1499
 
1500
  noise_pred_uncond, noise_pred_text = torch.chunk(noise_pred, 2, dim=0)
1501
-
1502
- # Timing for batch addition and latent counter increment
1503
- torch.cuda.synchronize() # Synchronize GPU before batch addition
1504
- time_batch_addition_start = time.time()
1505
 
1506
  # Perform batch addition
1507
  noise_pred_uncond_sum[..., current_context_indexes, :, :] += noise_pred_uncond
 
1176
  lora_scale=text_encoder_lora_scale,
1177
  clip_skip=clip_skip,
1178
  )
1179
+ # print prompt embed shape
1180
+ print("prompt_embeds shape after encoding")
1181
+ print(prompt_embeds.shape)
1182
+
1183
  # For classifier free guidance, we need to do two forward passes.
1184
  # Here we concatenate the unconditional and text embeddings into a single batch
1185
  # to avoid doing two forward passes
1186
  if do_classifier_free_guidance:
1187
+ # concatenate negative prompt embeddings with prompt embeddings on a new dimension after the first batch dimension
1188
+ prompt_embeds = torch.stack([negative_prompt_embeds, prompt_embeds], dim=1)
1189
 
1190
  if ip_adapter_image is not None:
1191
  output_hidden_state = False if isinstance(self.unet.encoder_hid_proj, ImageProjection) else True
 
1407
 
1408
  # Denoising loop
1409
  num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
1410
+ # get the number of prompt from the 1st dimension of prompt_embeds
1411
+ num_prompts = prompt_embeds.shape[0]
1412
  with self.progress_bar(total=len(timesteps)) as progress_bar:
1413
  for i, t in enumerate(timesteps):
1414
  noise_pred_uncond_sum = torch.zeros_like(latents).to(device).to(dtype=torch.float16)
 
1428
  latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
1429
 
1430
  if self.controlnet != None and i < int(control_end*num_inference_steps):
 
 
 
1431
 
1432
  current_context_conditioning_frames = conditioning_frames[current_context_indexes, :, :, :]
1433
  current_context_conditioning_frames = torch.cat([current_context_conditioning_frames] * 2) if do_classifier_free_guidance else current_context_conditioning_frames
 
1458
  )
1459
 
1460
 
1461
+ # get the current prompt index based on the current context position (for blending between multiple prompts)
1462
+ context_position = current_context_indexes[0] % context_size
1463
+ current_prompt_index = int(context_position / (context_size / num_prompts))
1464
+
1465
  down_block_res_samples, mid_block_res_sample = self.controlnet(
1466
  control_model_input,
1467
  t,
 
1472
  return_dict=False,
1473
  )
1474
 
 
1475
  # predict the noise residual with the added controlnet residuals
1476
  noise_pred = self.unet(
1477
  latent_model_input,
1478
  t,
1479
+ encoder_hidden_states=prompt_embeds[current_prompt_index],
1480
  cross_attention_kwargs=cross_attention_kwargs,
1481
  added_cond_kwargs=added_cond_kwargs,
1482
  down_block_additional_residuals=down_block_res_samples,
 
1485
 
1486
  else:
1487
  # predict the noise residual without controlnet
 
 
1488
  noise_pred = self.unet(
1489
  latent_model_input,
1490
  t,
 
1494
  ).sample
1495
 
1496
  if do_classifier_free_guidance:
 
 
 
 
 
 
 
1497
 
1498
  noise_pred_uncond, noise_pred_text = torch.chunk(noise_pred, 2, dim=0)
 
 
 
 
1499
 
1500
  # Perform batch addition
1501
  noise_pred_uncond_sum[..., current_context_indexes, :, :] += noise_pred_uncond