smoothieAI committed
Commit be3d287 · verified · 1 parent: f6f9bc5

Update pipeline.py

Files changed (1): pipeline.py (+21 −21)
pipeline.py CHANGED
@@ -193,14 +193,14 @@ class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdap
        EulerAncestralDiscreteScheduler,
        DPMSolverMultistepScheduler,
    ],
-        controlnet: Optional[Union[ControlNetModel, MultiControlNetModel]]=None,
+        controlnets: Optional[Union[ControlNetModel, MultiControlNetModel]]=None,
        feature_extractor: Optional[CLIPImageProcessor] = None,
        image_encoder: Optional[CLIPVisionModelWithProjection] = None,
    ):
        super().__init__()
        unet = UNetMotionModel.from_unet2d(unet, motion_adapter)

-        if controlnet is None:
+        if controlnets is None:
            self.register_modules(
                vae=vae,
                text_encoder=text_encoder,
@@ -218,7 +218,7 @@ class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdap
                tokenizer=tokenizer,
                unet=unet,
                motion_adapter=motion_adapter,
-                controlnet=controlnet,
+                controlnet=controlnets,
                scheduler=scheduler,
                feature_extractor=feature_extractor,
                image_encoder=image_encoder,
@@ -1117,8 +1117,8 @@ class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdap
            returned, otherwise a `tuple` is returned where the first element is a list with the generated frames.
        """

-        if self.controlnet != None:
-            controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet
+        if self.controlnets != None:
+            controlnets = self.controlnets._orig_mod if is_compiled_module(self.controlnets) else self.controlnets

        # align format for control guidance
        control_end = control_guidance_end
@@ -1127,7 +1127,7 @@ class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdap
        elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
            control_guidance_end = len(control_guidance_start) * [control_guidance_end]
        elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):
-            mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1
+            mult = len(controlnets.nets) if isinstance(controlnets, MultiControlNetModel) else 1
            control_guidance_start, control_guidance_end = (
                mult * [control_guidance_start],
                mult * [control_guidance_end],
@@ -1155,14 +1155,14 @@ class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdap

        device = self._execution_device

-        if self.controlnet != None:
-            if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float):
-                controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets)
+        if self.controlnets != None:
+            if isinstance(controlnets, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float):
+                controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnets.nets)

            global_pool_conditions = (
-                controlnet.config.global_pool_conditions
-                if isinstance(controlnet, ControlNetModel)
-                else controlnet.nets[0].config.global_pool_conditions
+                controlnets.config.global_pool_conditions
+                if isinstance(controlnets, ControlNetModel)
+                else controlnets.nets[0].config.global_pool_conditions
            )
            guess_mode = guess_mode or global_pool_conditions

@@ -1201,8 +1201,8 @@ class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdap
            if do_classifier_free_guidance:
                image_embeds = torch.cat([negative_image_embeds, image_embeds])

-        if self.controlnet != None:
-            if isinstance(controlnet, ControlNetModel):
+        if self.controlnets != None:
+            if isinstance(controlnets, ControlNetModel):
                # conditioning_frames = self.prepare_image(
                #     image=conditioning_frames,
                #     width=width,
@@ -1221,12 +1221,12 @@ class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdap
                    batch_size=batch_size * num_videos_per_prompt * num_frames,
                    num_images_per_prompt=num_videos_per_prompt,
                    device=device,
-                    dtype=controlnet.dtype,
+                    dtype=controlnets.dtype,
                    do_classifier_free_guidance=do_classifier_free_guidance,
                    guess_mode=guess_mode,
                )

-            elif isinstance(controlnet, MultiControlNetModel):
+            elif isinstance(controlnets, MultiControlNetModel):
                cond_prepared_frames = []
                for frame_ in conditioning_frames:
                    # prepared_frame = self.prepare_image(
@@ -1248,7 +1248,7 @@ class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdap
                        batch_size=batch_size * num_videos_per_prompt * num_frames,
                        num_images_per_prompt=num_videos_per_prompt,
                        device=device,
-                        dtype=controlnet.dtype,
+                        dtype=controlnets.dtype,
                        do_classifier_free_guidance=do_classifier_free_guidance,
                        guess_mode=guess_mode,
                    )
@@ -1367,14 +1367,14 @@ class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdap
        added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else None

        # 7.1 Create tensor stating which controlnets to keep
-        if self.controlnet != None:
+        if self.controlnets != None:
            controlnet_keep = []
            for i in range(len(timesteps)):
                keeps = [
                    1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
                    for s, e in zip(control_guidance_start, control_guidance_end)
                ]
-                controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps)
+                controlnet_keep.append(keeps[0] if isinstance(controlnets, ControlNetModel) else keeps)

        # divide the initial latents into context groups

@@ -1431,7 +1431,7 @@ class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdap
                latent_model_input = torch.cat([current_context_latents] * 2) if do_classifier_free_guidance else current_context_latents
                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

-                if self.controlnet != None and i < int(control_end*num_inference_steps):
+                if self.controlnets != None and i < int(control_end*num_inference_steps):

                    torch.cuda.synchronize() # Synchronize GPU
                    control_start = time.time()
@@ -1465,7 +1465,7 @@ class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdap
                    )


-                    down_block_res_samples, mid_block_res_sample = self.controlnet(
+                    down_block_res_samples, mid_block_res_sample = self.controlnets(
                        control_model_input,
                        t,
                        encoder_hidden_states=controlnet_prompt_embeds,
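
For context, a minimal usage sketch of the renamed constructor argument follows. It is not part of this commit: the checkpoint names, the `from pipeline import AnimateDiffPipeline` import, and the surrounding loading code are assumptions for illustration; only the `controlnets` keyword and the accepted types (a single ControlNetModel or a MultiControlNetModel) come from the signature changed above.

# Hypothetical usage sketch; not part of this commit. Model IDs and the import
# path are assumptions for illustration only.
from diffusers import (
    AutoencoderKL,
    ControlNetModel,
    EulerAncestralDiscreteScheduler,
    MotionAdapter,
    UNet2DConditionModel,
)
from transformers import CLIPTextModel, CLIPTokenizer

from pipeline import AnimateDiffPipeline  # the file modified in this commit

base = "runwayml/stable-diffusion-v1-5"  # assumed SD 1.5 base checkpoint
vae = AutoencoderKL.from_pretrained(base, subfolder="vae")
text_encoder = CLIPTextModel.from_pretrained(base, subfolder="text_encoder")
tokenizer = CLIPTokenizer.from_pretrained(base, subfolder="tokenizer")
unet = UNet2DConditionModel.from_pretrained(base, subfolder="unet")
scheduler = EulerAncestralDiscreteScheduler.from_pretrained(base, subfolder="scheduler")
motion_adapter = MotionAdapter.from_pretrained("guoyww/animatediff-motion-adapter-v1-5-2")
controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-openpose")

# After this commit the ControlNet(s) are passed via the `controlnets` keyword;
# a list of ControlNetModel instances could be wrapped in a MultiControlNetModel instead.
pipe = AnimateDiffPipeline(
    vae=vae,
    text_encoder=text_encoder,
    tokenizer=tokenizer,
    unet=unet,  # 2D UNet; the pipeline converts it via UNetMotionModel.from_unet2d
    motion_adapter=motion_adapter,
    scheduler=scheduler,
    controlnets=controlnet,
)
pipe = pipe.to("cuda")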