smoothieAI committed (verified)
Commit b58a073 · Parent(s): 701a6d6

Update pipeline.py

Files changed (1): pipeline.py (+141 -10)
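For orientation, this commit wires an optional ControlNet (or MultiControlNet) module through the pipeline constructor and threads ControlNet conditioning through `__call__`. Below is a minimal construction sketch under the new signature; the checkpoint names and the repo id are illustrative placeholders, not values taken from this commit, and the exact loading setup may differ.

import torch
from diffusers import ControlNetModel, DiffusionPipeline, MotionAdapter

# Placeholder checkpoints; substitute the models you actually use.
adapter = MotionAdapter.from_pretrained(
    "guoyww/animatediff-motion-adapter-v1-5-2", torch_dtype=torch.float16
)
controlnet = ControlNetModel.from_pretrained(
    "lllyasviel/sd-controlnet-openpose", torch_dtype=torch.float16
)

# Load this repo's custom pipeline.py so the new `controlnet` argument is available;
# "smoothieAI/<repo-id>" is a placeholder for the repo that hosts this file.
pipe = DiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    custom_pipeline="smoothieAI/<repo-id>",
    motion_adapter=adapter,
    controlnet=controlnet,
    torch_dtype=torch.float16,
)
pipe.to("cuda")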
pipeline.py CHANGED
@@ -23,7 +23,7 @@ from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPV
 # Updated to use absolute paths
 from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
 from diffusers.loaders import IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin
-from diffusers.models import AutoencoderKL, ImageProjection, UNet2DConditionModel, UNetMotionModel
+from diffusers.models import AutoencoderKL, ImageProjection, UNet2DConditionModel, UNetMotionModel, ControlNetModel, MultiControlNetModel
 from diffusers.models.lora import adjust_lora_scale_text_encoder
 from diffusers.models.unet_motion_model import MotionAdapter
 from diffusers.schedulers import (
@@ -136,6 +136,7 @@ class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdap
             EulerAncestralDiscreteScheduler,
             DPMSolverMultistepScheduler,
         ],
+        controlnet: Optional[Union[ControlNetModel, MultiControlNetModel]] = None,
         feature_extractor: CLIPImageProcessor = None,
         image_encoder: CLIPVisionModelWithProjection = None,
     ):
@@ -148,12 +149,16 @@ class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdap
             tokenizer=tokenizer,
             unet=unet,
             motion_adapter=motion_adapter,
+            controlnet=controlnet,
             scheduler=scheduler,
             feature_extractor=feature_extractor,
             image_encoder=image_encoder,
         )
         self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
         self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
+        self.control_image_processor = VaeImageProcessor(
+            vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False
+        )
 
     def load_motion_adapter(self,motion_adapter):
         self.register_modules(motion_adapter=motion_adapter)
@@ -846,6 +851,10 @@ class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdap
         smooth_weight: Optional[float] = 0.5,
         smooth_steps: Optional[int] = 3,
         initial_context_scale: Optional[float] = 1.0,
+        conditioning_frames: Optional[List[PipelineImageInput]] = None,
+        controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
+        control_guidance_start: Union[float, List[float]] = 0.0,
+        control_guidance_end: Union[float, List[float]] = 1.0,
     ):
         r"""
         The call function to the pipeline for generation.
@@ -910,6 +919,23 @@ class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdap
                 If `return_dict` is `True`, [`~pipelines.text_to_video_synthesis.TextToVideoSDPipelineOutput`] is
                 returned, otherwise a `tuple` is returned where the first element is a list with the generated frames.
         """
+
+        if controlnet != None:
+            controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet
+
+            # align format for control guidance
+            if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
+                control_guidance_start = len(control_guidance_end) * [control_guidance_start]
+            elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
+                control_guidance_end = len(control_guidance_start) * [control_guidance_end]
+            elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):
+                mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1
+                control_guidance_start, control_guidance_end = (
+                    mult * [control_guidance_start],
+                    mult * [control_guidance_end],
+                )
+
+
         # 0. Default height and width to unet
         height = height or self.unet.config.sample_size * self.vae_scale_factor
         width = width or self.unet.config.sample_size * self.vae_scale_factor
@@ -930,6 +956,19 @@ class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdap
             batch_size = prompt_embeds.shape[0]
 
         device = self._execution_device
+
+        if controlnet != None:
+            if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float):
+                controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets)
+
+            global_pool_conditions = (
+                controlnet.config.global_pool_conditions
+                if isinstance(controlnet, ControlNetModel)
+                else controlnet.nets[0].config.global_pool_conditions
+            )
+            guess_mode = guess_mode or global_pool_conditions
+
+
         # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
         # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
         # corresponds to doing no classifier free guidance.
@@ -964,6 +1003,40 @@ class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdap
             if do_classifier_free_guidance:
                 image_embeds = torch.cat([negative_image_embeds, image_embeds])
 
+        if controlnet != None:
+            if isinstance(controlnet, ControlNetModel):
+                conditioning_frames = self.prepare_image(
+                    image=conditioning_frames,
+                    width=width,
+                    height=height,
+                    batch_size=batch_size * num_videos_per_prompt * num_frames,
+                    num_images_per_prompt=num_videos_per_prompt,
+                    device=device,
+                    dtype=controlnet.dtype,
+                    do_classifier_free_guidance=self.do_classifier_free_guidance,
+                    guess_mode=guess_mode,
+                )
+            elif isinstance(controlnet, MultiControlNetModel):
+                cond_prepared_frames = []
+                for frame_ in conditioning_frames:
+                    prepared_frame = self.prepare_image(
+                        image=frame_,
+                        width=width,
+                        height=height,
+                        batch_size=batch_size * num_videos_per_prompt * num_frames,
+                        num_images_per_prompt=num_videos_per_prompt,
+                        device=device,
+                        dtype=controlnet.dtype,
+                        do_classifier_free_guidance=self.do_classifier_free_guidance,
+                        guess_mode=guess_mode,
+                    )
+
+                    cond_prepared_frames.append(prepared_frame)
+
+                conditioning_frames = cond_prepared_frames
+            else:
+                assert False
+
         # 4. Prepare timesteps
         self.scheduler.set_timesteps(num_inference_steps, device=device)
         timesteps = self.scheduler.timesteps
@@ -1051,6 +1124,16 @@ class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdap
         # 7 Add image embeds for IP-Adapter
         added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else None
 
+        # 7.1 Create tensor stating which controlnets to keep
+        if controlnet != None:
+            controlnet_keep = []
+            for i in range(len(timesteps)):
+                keeps = [
+                    1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
+                    for s, e in zip(control_guidance_start, control_guidance_end)
+                ]
+                controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps)
+
         # divide the initial latents into context groups
 
         def context_scheduler(context_size, overlap, offset, total_frames, total_timesteps):
@@ -1105,15 +1188,63 @@ class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdap
                     # expand the latents if we are doing classifier free guidance
                    latent_model_input = torch.cat([current_context_latents] * 2) if do_classifier_free_guidance else current_context_latents
                    latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+
+
+
+                    if controlnet != None:
+                        if guess_mode and self.do_classifier_free_guidance:
+                            # Infer ControlNet only for the conditional batch.
+                            control_model_input = latents
+                            control_model_input = self.scheduler.scale_model_input(control_model_input, t)
+                            controlnet_prompt_embeds = prompt_embeds.chunk(2)[1]
+                        else:
+                            control_model_input = latent_model_input
+                            controlnet_prompt_embeds = prompt_embeds
+                        controlnet_prompt_embeds = controlnet_prompt_embeds.repeat_interleave(num_frames, dim=0)
+
+                        if isinstance(controlnet_keep[i], list):
+                            cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])]
+                        else:
+                            controlnet_cond_scale = controlnet_conditioning_scale
+                            if isinstance(controlnet_cond_scale, list):
+                                controlnet_cond_scale = controlnet_cond_scale[0]
+                            cond_scale = controlnet_cond_scale * controlnet_keep[i]
+
+                        control_model_input = torch.transpose(control_model_input, 1, 2)
+                        control_model_input = control_model_input.reshape(
+                            (-1, control_model_input.shape[2], control_model_input.shape[3], control_model_input.shape[4])
+                        )
 
-                    # predict the noise residual
-                    noise_pred = self.unet(
-                        latent_model_input,
-                        t,
-                        encoder_hidden_states=prompt_embeds,
-                        cross_attention_kwargs=cross_attention_kwargs,
-                        added_cond_kwargs=added_cond_kwargs,
-                    ).sample
+                        down_block_res_samples, mid_block_res_sample = self.controlnet(
+                            control_model_input,
+                            t,
+                            encoder_hidden_states=controlnet_prompt_embeds,
+                            controlnet_cond=conditioning_frames,
+                            conditioning_scale=cond_scale,
+                            guess_mode=guess_mode,
+                            return_dict=False,
+                        )
+                        # predict the noise residual with the added controlnet residuals
+                        noise_pred = self.unet(
+                            latent_model_input,
+                            t,
+                            encoder_hidden_states=prompt_embeds,
+                            cross_attention_kwargs=cross_attention_kwargs,
+                            added_cond_kwargs=added_cond_kwargs,
+                            down_block_additional_residuals=down_block_res_samples,
+                            mid_block_additional_residual=mid_block_res_sample,
+                        ).sample
+
+                    else:
+                        # predict the noise residual
+                        noise_pred = self.unet(
+                            latent_model_input,
+                            t,
+                            encoder_hidden_states=prompt_embeds,
+                            cross_attention_kwargs=cross_attention_kwargs,
+                            added_cond_kwargs=added_cond_kwargs,
+                        ).sample
 
                     # sum the noise predictions for the unconditional and text conditioned noise
                     if do_classifier_free_guidance:
@@ -1176,4 +1307,4 @@ class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdap
         if not return_dict:
             return (video,)
 
-        return AnimateDiffPipelineOutput(frames=video)
+        return AnimateDiffPipelineOutput(frames=video)
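For reference, a sketch of a call that exercises the arguments added to `__call__` in this commit (`conditioning_frames`, `controlnet_conditioning_scale`, `control_guidance_start`, `control_guidance_end`). Here `pipe` is the pipeline constructed in the sketch above, `pose_images` is a hypothetical list of per-frame control images, and the prompt, step count, and frame count are placeholders.

# pose_images: hypothetical list of PIL control images, one per generated frame.
num_frames = 16  # assumed to match the pipeline's frame count argument

output = pipe(
    prompt="a robot dancing in the rain",   # placeholder prompt
    num_frames=num_frames,
    num_inference_steps=25,
    conditioning_frames=pose_images,        # new: per-frame ControlNet conditioning
    controlnet_conditioning_scale=1.0,      # new: weight applied to the ControlNet residuals
    control_guidance_start=0.0,             # new: fraction of steps at which control starts
    control_guidance_end=1.0,               # new: fraction of steps at which control stops
)
video_frames = output.frames  # AnimateDiffPipelineOutput.frames, per the return above

With a MultiControlNetModel, the alignment code in the hunk at line 919 broadcasts scalar `control_guidance_start`/`control_guidance_end` values to one entry per ControlNet, and the preparation code in the hunk at line 1003 appears to expect `conditioning_frames` to be a list with one image sequence per ControlNet.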