Update pipeline.py
pipeline.py  (+21 -21)
@@ -193,14 +193,14 @@ class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdap
             EulerAncestralDiscreteScheduler,
             DPMSolverMultistepScheduler,
         ],
-
+        controlnets: Optional[Union[ControlNetModel, MultiControlNetModel]]=None,
         feature_extractor: Optional[CLIPImageProcessor] = None,
         image_encoder: Optional[CLIPVisionModelWithProjection] = None,
     ):
         super().__init__()
         unet = UNetMotionModel.from_unet2d(unet, motion_adapter)

-        if
+        if controlnets is None:
         self.register_modules(
             vae=vae,
             text_encoder=text_encoder,
@@ -218,7 +218,7 @@ class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdap
             tokenizer=tokenizer,
             unet=unet,
             motion_adapter=motion_adapter,
-            controlnet=
+            controlnet=controlnets,
             scheduler=scheduler,
             feature_extractor=feature_extractor,
             image_encoder=image_encoder,
@@ -1117,8 +1117,8 @@ class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdap
             returned, otherwise a `tuple` is returned where the first element is a list with the generated frames.
         """

-        if self.
-
+        if self.controlnets != None:
+            controlnets = self.controlnets._orig_mod if is_compiled_module(self.controlnets) else self.controlnets

         # align format for control guidance
         control_end = control_guidance_end
@@ -1127,7 +1127,7 @@ class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdap
         elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
             control_guidance_end = len(control_guidance_start) * [control_guidance_end]
         elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):
-            mult = len(
+            mult = len(controlnets.nets) if isinstance(controlnets, MultiControlNetModel) else 1
             control_guidance_start, control_guidance_end = (
                 mult * [control_guidance_start],
                 mult * [control_guidance_end],
@@ -1155,14 +1155,14 @@ class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdap

         device = self._execution_device

-        if self.
-        if isinstance(
-            controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(
+        if self.controlnets != None:
+            if isinstance(controlnets, MultiControlNetModel) and isinstance(controlnet_conditioning_scale, float):
+                controlnet_conditioning_scale = [controlnet_conditioning_scale] * len(controlnets.nets)

         global_pool_conditions = (
-
-        if isinstance(
-        else
+            controlnets.config.global_pool_conditions
+            if isinstance(controlnets, ControlNetModel)
+            else controlnets.nets[0].config.global_pool_conditions
         )
         guess_mode = guess_mode or global_pool_conditions

@@ -1201,8 +1201,8 @@ class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdap
         if do_classifier_free_guidance:
             image_embeds = torch.cat([negative_image_embeds, image_embeds])

-        if self.
-        if isinstance(
+        if self.controlnets != None:
+            if isinstance(controlnets, ControlNetModel):
                 # conditioning_frames = self.prepare_image(
                 #     image=conditioning_frames,
                 #     width=width,
@@ -1221,12 +1221,12 @@ class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdap
                     batch_size=batch_size * num_videos_per_prompt * num_frames,
                     num_images_per_prompt=num_videos_per_prompt,
                     device=device,
-                    dtype=
+                    dtype=controlnets.dtype,
                     do_classifier_free_guidance=do_classifier_free_guidance,
                     guess_mode=guess_mode,
                 )

-            elif isinstance(
+            elif isinstance(controlnets, MultiControlNetModel):
                 cond_prepared_frames = []
                 for frame_ in conditioning_frames:
                     # prepared_frame = self.prepare_image(
@@ -1248,7 +1248,7 @@ class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdap
                         batch_size=batch_size * num_videos_per_prompt * num_frames,
                         num_images_per_prompt=num_videos_per_prompt,
                         device=device,
-                        dtype=
+                        dtype=controlnets.dtype,
                         do_classifier_free_guidance=do_classifier_free_guidance,
                         guess_mode=guess_mode,
                     )
@@ -1367,14 +1367,14 @@ class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdap
         added_cond_kwargs = {"image_embeds": image_embeds} if ip_adapter_image is not None else None

         # 7.1 Create tensor stating which controlnets to keep
-        if self.
+        if self.controlnets != None:
             controlnet_keep = []
             for i in range(len(timesteps)):
                 keeps = [
                     1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
                     for s, e in zip(control_guidance_start, control_guidance_end)
                 ]
-                controlnet_keep.append(keeps[0] if isinstance(
+                controlnet_keep.append(keeps[0] if isinstance(controlnets, ControlNetModel) else keeps)

         # divide the initial latents into context groups

@@ -1431,7 +1431,7 @@ class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdap
                 latent_model_input = torch.cat([current_context_latents] * 2) if do_classifier_free_guidance else current_context_latents
                 latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

-                if self.
+                if self.controlnets != None and i < int(control_end*num_inference_steps):

                     torch.cuda.synchronize()  # Synchronize GPU
                     control_start = time.time()
@@ -1465,7 +1465,7 @@ class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdap
                     )


-                    down_block_res_samples, mid_block_res_sample = self.
+                    down_block_res_samples, mid_block_res_sample = self.controlnets(
                         control_model_input,
                         t,
                         encoder_hidden_states=controlnet_prompt_embeds,