Update pipeline.py
pipeline.py (changed): +141 -10

@@ -23,7 +23,7 @@ from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPV
 # Updated to use absolute paths
 from diffusers.image_processor import PipelineImageInput, VaeImageProcessor
 from diffusers.loaders import IPAdapterMixin, LoraLoaderMixin, TextualInversionLoaderMixin
-from diffusers.models import AutoencoderKL, ImageProjection, UNet2DConditionModel, UNetMotionModel
+from diffusers.models import AutoencoderKL, ImageProjection, UNet2DConditionModel, UNetMotionModel, ControlNetModel, MultiControlNetModel
 from diffusers.models.lora import adjust_lora_scale_text_encoder
 from diffusers.models.unet_motion_model import MotionAdapter
 from diffusers.schedulers import (
@@ -136,6 +136,7 @@ class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdap
             EulerAncestralDiscreteScheduler,
             DPMSolverMultistepScheduler,
         ],
+        controlnet: Optional[Union[ControlNetModel, MultiControlNetModel]] = None,
         feature_extractor: CLIPImageProcessor = None,
         image_encoder: CLIPVisionModelWithProjection = None,
     ):
@@ -148,12 +149,16 @@ class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdap
             tokenizer=tokenizer,
             unet=unet,
             motion_adapter=motion_adapter,
+            controlnet=controlnet,
             scheduler=scheduler,
             feature_extractor=feature_extractor,
             image_encoder=image_encoder,
         )
         self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
         self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
+        self.control_image_processor = VaeImageProcessor(
+            vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False
+        )
 
     def load_motion_adapter(self,motion_adapter):
         self.register_modules(motion_adapter=motion_adapter)
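
The new `control_image_processor` mirrors the one used by diffusers' Stable Diffusion ControlNet pipelines: `do_normalize=False` keeps control images in [0, 1] rather than [-1, 1], and `do_convert_rgb=True` tolerates single-channel inputs such as canny maps. For context, below is a minimal wiring sketch (not part of this commit) showing how a caller would supply the optional `controlnet` component; the checkpoint names are placeholders, and the sketch assumes the standard diffusers `from_pretrained` path, which forwards extra components like `motion_adapter` and `controlnet` to the modified `__init__` above.

import torch
from diffusers import ControlNetModel, MotionAdapter

# Illustrative checkpoints only; any SD-1.5-compatible ControlNet and motion adapter work the same way.
controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-openpose", torch_dtype=torch.float16)
adapter = MotionAdapter.from_pretrained("guoyww/animatediff-motion-adapter-v1-5-2", torch_dtype=torch.float16)

# AnimateDiffPipeline here is the modified class defined in this pipeline.py.
pipe = AnimateDiffPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    motion_adapter=adapter,
    controlnet=controlnet,
    torch_dtype=torch.float16,
)

Because `controlnet` defaults to `None`, existing callers that never pass one are unaffected; `register_modules(controlnet=None)` simply records the component as absent.
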
@@ -846,6 +851,10 @@ class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdap
         smooth_weight: Optional[float] = 0.5,
         smooth_steps: Optional[int] = 3,
         initial_context_scale: Optional[float] = 1.0,
+        conditioning_frames: Optional[List[PipelineImageInput]] = None,
+        controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
+        control_guidance_start: Union[float, List[float]] = 0.0,
+        control_guidance_end: Union[float, List[float]] = 1.0,
     ):
         r"""
         The call function to the pipeline for generation.
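
An illustrative call using the four new arguments (again not part of the diff; `pipe` and `pose_frames` are placeholders, and pre-existing arguments such as `prompt`, `num_frames`, and `num_inference_steps` are assumed from the unchanged portion of the signature):

# pose_frames: one control image per generated frame, e.g. 16 openpose maps as PIL images.
result = pipe(
    prompt="a dancer spinning, studio lighting",
    num_frames=16,
    num_inference_steps=25,
    conditioning_frames=pose_frames,
    controlnet_conditioning_scale=1.0,  # float, or one value per ControlNet
    control_guidance_start=0.0,         # fraction of steps at which ControlNet starts
    control_guidance_end=1.0,           # fraction of steps after which it is dropped
)
video_frames = result.frames
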
@@ -910,6 +919,23 @@ class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdap
             If `return_dict` is `True`, [`~pipelines.text_to_video_synthesis.TextToVideoSDPipelineOutput`] is
             returned, otherwise a `tuple` is returned where the first element is a list with the generated frames.
         """
+        controlnet = self.controlnet._orig_mod if is_compiled_module(self.controlnet) else self.controlnet
+
+        if controlnet is not None:
+            # align format for control guidance
+            if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
+                control_guidance_start = len(control_guidance_end) * [control_guidance_start]
+            elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
+                control_guidance_end = len(control_guidance_start) * [control_guidance_end]
+            elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):
+                mult = len(controlnet.nets) if isinstance(controlnet, MultiControlNetModel) else 1
+                control_guidance_start, control_guidance_end = (
+                    mult * [control_guidance_start],
+                    mult * [control_guidance_end],
+                )
+
         # 0. Default height and width to unet
         height = height or self.unet.config.sample_size * self.vae_scale_factor
         width = width or self.unet.config.sample_size * self.vae_scale_factor
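
Note that this block relies on `is_compiled_module`, which is assumed to already be imported in this file (in current diffusers it lives in `diffusers.utils.torch_utils`); the diff itself does not add that import. The alignment logic only broadcasts scalar `control_guidance_start`/`control_guidance_end` values to lists so that later per-ControlNet indexing is uniform, as in this standalone illustration of the common case (both passed as floats, two stacked ControlNets):

control_guidance_start, control_guidance_end = 0.0, 0.8
mult = 2  # e.g. a MultiControlNetModel wrapping two nets
if not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):
    control_guidance_start, control_guidance_end = (
        mult * [control_guidance_start],
        mult * [control_guidance_end],
    )
print(control_guidance_start, control_guidance_end)  # [0.0, 0.0] [0.8, 0.8]
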
@@ -930,6 +956,19 @@ class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdap
             batch_size = prompt_embeds.shape[0]
 
         device = self._execution_device
+
+        if controlnet is not None:
+            if isinstance(controlnet, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float):
+                controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnet.nets)
+
+            global_pool_conditions = (
+                controlnet.config.global_pool_conditions
+                if isinstance(controlnet, ControlNetModel)
+                else controlnet.nets[0].config.global_pool_conditions
+            )
+            guess_mode = guess_mode or global_pool_conditions
+
         # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
         # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
         # corresponds to doing no classifier free guidance.
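
Two normalizations happen here: a single float `controlnet_conditioning_scale` is expanded to one entry per ControlNet, and `guess_mode` (assumed to be an existing `__call__` argument, since this diff does not add it) is forced on when the ControlNet was trained with `global_pool_conditions`. For example:

controlnet_conditioning_scale = 0.7
num_nets = 3  # a MultiControlNetModel wrapping three nets
if isinstance(controlnet_conditioning_scale, float):
    controlnet_conditioning_scale = [controlnet_conditioning_scale] * num_nets
print(controlnet_conditioning_scale)  # [0.7, 0.7, 0.7]
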
@@ -964,6 +1003,40 @@ class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdap
             if do_classifier_free_guidance:
                 image_embeds = torch.cat([negative_image_embeds, image_embeds])
 
+        if controlnet is not None:
+            if isinstance(controlnet, ControlNetModel):
+                conditioning_frames = self.prepare_image(
+                    image=conditioning_frames,
+                    width=width,
+                    height=height,
+                    batch_size=batch_size * num_videos_per_prompt * num_frames,
+                    num_images_per_prompt=num_videos_per_prompt,
+                    device=device,
+                    dtype=controlnet.dtype,
+                    do_classifier_free_guidance=self.do_classifier_free_guidance,
+                    guess_mode=guess_mode,
+                )
+            elif isinstance(controlnet, MultiControlNetModel):
+                cond_prepared_frames = []
+                for frame_ in conditioning_frames:
+                    prepared_frame = self.prepare_image(
+                        image=frame_,
+                        width=width,
+                        height=height,
+                        batch_size=batch_size * num_videos_per_prompt * num_frames,
+                        num_images_per_prompt=num_videos_per_prompt,
+                        device=device,
+                        dtype=controlnet.dtype,
+                        do_classifier_free_guidance=self.do_classifier_free_guidance,
+                        guess_mode=guess_mode,
+                    )
+                    cond_prepared_frames.append(prepared_frame)
+                conditioning_frames = cond_prepared_frames
+            else:
+                assert False
+
         # 4. Prepare timesteps
         self.scheduler.set_timesteps(num_inference_steps, device=device)
         timesteps = self.scheduler.timesteps
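
The frame preparation defers to a `self.prepare_image` helper that is not shown in this diff; it is assumed to follow the same contract as the helper of the same name in diffusers' Stable Diffusion ControlNet pipelines (resize to `height`/`width`, convert to a float tensor via `control_image_processor`, move to the ControlNet's device and dtype, and duplicate the batch for classifier-free guidance unless guess mode is active). The sketch below is only a rough approximation of that behaviour for the single-ControlNet branch.

import torch

def prepare_control_frames_sketch(control_image_processor, frames, width, height, device, dtype, do_cfg, guess_mode):
    # Rough approximation; the real prepare_image also repeats the batch to
    # batch_size * num_videos_per_prompt * num_frames before duplicating for CFG.
    image = control_image_processor.preprocess(frames, height=height, width=width)  # (F, 3, H, W), values in [0, 1]
    image = image.to(device=device, dtype=dtype)
    if do_cfg and not guess_mode:
        image = torch.cat([image] * 2)  # unconditional + conditional copies
    return image
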
@@ -1051,6 +1124,16 @@ class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdap
         # 7 Add image embeds for IP-Adapter
         added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else None
 
+        # 7.1 Create tensor stating which controlnets to keep
+        if controlnet is not None:
+            controlnet_keep = []
+            for i in range(len(timesteps)):
+                keeps = [
+                    1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
+                    for s, e in zip(control_guidance_start, control_guidance_end)
+                ]
+                controlnet_keep.append(keeps[0] if isinstance(controlnet, ControlNetModel) else keeps)
+
         # divide the initial latents into context groups
 
         def context_scheduler(context_size, overlap, offset, total_frames, total_timesteps):
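
`controlnet_keep` is a per-timestep on/off mask derived from `control_guidance_start`/`control_guidance_end` (one float per step for a single ControlNet, a list per step for a MultiControlNetModel). A worked example with 10 steps and control guidance ending at 50% of the schedule:

timesteps = list(range(10))
control_guidance_start, control_guidance_end = [0.0], [0.5]
controlnet_keep = []
for i in range(len(timesteps)):
    keeps = [
        1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
        for s, e in zip(control_guidance_start, control_guidance_end)
    ]
    controlnet_keep.append(keeps[0])  # single-ControlNet case
print(controlnet_keep)  # [1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0]
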
@@ -1105,15 +1188,63 @@ class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdap
                 # expand the latents if we are doing classifier free guidance
                 latent_model_input = torch.cat([current_context_latents] * 2) if do_classifier_free_guidance else current_context_latents
                 latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+                if controlnet is not None:
+                    if guess_mode and self.do_classifier_free_guidance:
+                        # Infer ControlNet only for the conditional batch.
+                        control_model_input = latents
+                        control_model_input = self.scheduler.scale_model_input(control_model_input, t)
+                        controlnet_prompt_embeds = prompt_embeds.chunk(2)[1]
+                    else:
+                        control_model_input = latent_model_input
+                        controlnet_prompt_embeds = prompt_embeds
+                    controlnet_prompt_embeds = controlnet_prompt_embeds.repeat_interleave(num_frames, dim=0)
+
+                    if isinstance(controlnet_keep[i], list):
+                        cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])]
+                    else:
+                        controlnet_cond_scale = controlnet_conditioning_scale
+                        if isinstance(controlnet_cond_scale, list):
+                            controlnet_cond_scale = controlnet_cond_scale[0]
+                        cond_scale = controlnet_cond_scale * controlnet_keep[i]
+
+                    # fold the frame axis into the batch axis: (B, C, F, H, W) -> (B * F, C, H, W)
+                    control_model_input = torch.transpose(control_model_input, 1, 2)
+                    control_model_input = control_model_input.reshape(
+                        (-1, control_model_input.shape[2], control_model_input.shape[3], control_model_input.shape[4])
+                    )
+
+                    down_block_res_samples, mid_block_res_sample = self.controlnet(
+                        control_model_input,
+                        t,
+                        encoder_hidden_states=controlnet_prompt_embeds,
+                        controlnet_cond=conditioning_frames,
+                        conditioning_scale=cond_scale,
+                        guess_mode=guess_mode,
+                        return_dict=False,
+                    )
+
+                    # predict the noise residual with the added controlnet residuals
+                    noise_pred = self.unet(
+                        latent_model_input,
+                        t,
+                        encoder_hidden_states=prompt_embeds,
+                        cross_attention_kwargs=cross_attention_kwargs,
+                        added_cond_kwargs=added_cond_kwargs,
+                        down_block_additional_residuals=down_block_res_samples,
+                        mid_block_additional_residual=mid_block_res_sample,
+                    ).sample
+                else:
+                    # predict the noise residual
+                    noise_pred = self.unet(
+                        latent_model_input,
+                        t,
+                        encoder_hidden_states=prompt_embeds,
+                        cross_attention_kwargs=cross_attention_kwargs,
+                        added_cond_kwargs=added_cond_kwargs,
+                    ).sample
 
                 # sum the noise predictions for the unconditional and text conditioned noise
                 if do_classifier_free_guidance:
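
The only video-specific work before calling the 2D ControlNet is folding the frame axis into the batch axis; the down/mid-block residuals it returns are then passed to the motion UNet unchanged. A shape walk-through with illustrative sizes (2 videos, 4 latent channels, 16 frames, 64x64 latents):

import torch

latents = torch.randn(2, 4, 16, 64, 64)                  # (batch, channels, frames, height, width)
x = torch.transpose(latents, 1, 2)                       # (2, 16, 4, 64, 64)
x = x.reshape((-1, x.shape[2], x.shape[3], x.shape[4]))  # (32, 4, 64, 64) == (batch * frames, channels, H, W)
assert x.shape == (2 * 16, 4, 64, 64)
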
@@ -1176,4 +1307,4 @@ class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdap
         if not return_dict:
             return (video,)
 
-        return AnimateDiffPipelineOutput(frames=video)
+        return AnimateDiffPipelineOutput(frames=video)