Update pipeline.py
pipeline.py  CHANGED  (+13 -41)
@@ -136,16 +136,13 @@ class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdapterMixin
             EulerAncestralDiscreteScheduler,
             DPMSolverMultistepScheduler,
         ],
-        controlnet: Union[ControlNetModel,
-        feature_extractor: CLIPImageProcessor = None,
-        image_encoder: CLIPVisionModelWithProjection = None,
+        controlnet: Union[ControlNetModel, MultiControlNetModel],
+        feature_extractor: Optional[CLIPImageProcessor] = None,
+        image_encoder: Optional[CLIPVisionModelWithProjection] = None,
     ):
         super().__init__()
         unet = UNetMotionModel.from_unet2d(unet, motion_adapter)
-
-        if isinstance(controlnet, (list, tuple)):
-            controlnet = MultiControlNetModel(controlnet)
-
+
         self.register_modules(
             vae=vae,
             text_encoder=text_encoder,
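With the list/tuple auto-wrap removed from the constructor, callers that use several ControlNets now have to wrap them in a MultiControlNetModel themselves before building the pipeline. A minimal sketch of both call patterns, assuming standard diffusers imports; the checkpoint names are illustrative only:

    from diffusers import ControlNetModel
    from diffusers.pipelines.controlnet import MultiControlNetModel

    # Single ControlNet: pass the model directly.
    controlnet = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-openpose")

    # Several ControlNets: wrap the list yourself, since the pipeline
    # no longer converts a plain list/tuple into a MultiControlNetModel.
    controlnet = MultiControlNetModel([
        ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-openpose"),
        ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-canny"),
    ])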
@@ -1260,31 +1257,15 @@ class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdapterMixin
                 # select the relevant context from the latents
                 current_context_latents = latents[:, :, current_context_indexes, :, :]

-                if self.controlnet != None:
-                    # if we are using multiple controlnets, select the context window for each controlnet
-                    if isinstance(controlnet, MultiControlNetModel):
-                        print("length of conditioning_frames", len(conditioning_frames))
-                        current_context_conditioning_frames = [conditioning_frames[c][current_context_indexes, :, :, :] for c in range(len(controlnet.nets))]
-                        # move to device
-                        current_context_conditioning_frames = [c.to(device) for c in current_context_conditioning_frames]
-                        # print shape of current context conditioning frames [0]
-                        print("shape of current context conditioning frames", current_context_conditioning_frames[0].shape)
-                        # print device
-                        print("device of current context conditioning frames", current_context_conditioning_frames[0].device)
-                    else:
-                        # select the relevant context from the conditioning frames of shape (frame_number, channel, height, width)
-                        current_context_conditioning_frames = conditioning_frames[current_context_indexes, :, :, :]
-                        current_context_conditioning_frames = current_context_conditioning_frames.to(device)
-                else:
-                    current_context_conditioning_frames = None
-
-
                 # expand the latents if we are doing classifier free guidance
                 latent_model_input = torch.cat([current_context_latents] * 2) if do_classifier_free_guidance else current_context_latents
                 latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)


                 if self.controlnet != None:
+
+                    current_context_conditioning_frames = conditioning_frames[current_context_indexes, :, :, :]
+
                     if guess_mode and self.do_classifier_free_guidance:
                         # Infer ControlNet only for the conditional batch.
                         control_model_input = latents
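The replacement branch indexes the conditioning frames with the same current_context_indexes used for the latents, so the ControlNet only sees the frames of the current context window. A self-contained sketch of that indexing, with hypothetical shapes (16 conditioning frames of 3x512x512 and a 4-frame window):

    import torch

    # conditioning frames laid out as (frame_number, channel, height, width)
    conditioning_frames = torch.randn(16, 3, 512, 512)
    current_context_indexes = [3, 4, 5, 6]  # sliding context window

    # same fancy indexing as the pipeline: select along the frame axis
    current_context_conditioning_frames = conditioning_frames[current_context_indexes, :, :, :]
    print(current_context_conditioning_frames.shape)  # torch.Size([4, 3, 512, 512])

Note that the deleted branch also moved the selection to the target device and handled the MultiControlNetModel case with one frame stack per net; the simplified code assumes a single stack that already lives on the right device.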
@@ -1302,21 +1283,12 @@ class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdapterMixin
                     if isinstance(controlnet_cond_scale, list):
                         controlnet_cond_scale = controlnet_cond_scale[0]
                     cond_scale = controlnet_cond_scale * controlnet_keep[i]
-
+
                     control_model_input = torch.transpose(control_model_input, 1, 2)
-                    control_model_input = control_model_input.reshape(
-
-
-
-                    print("device of control_model_input", control_model_input.device)
-                    print("device of controlnet_prompt_embeds", controlnet_prompt_embeds.device)
-                    print("device of current_context_conditioning_frames", current_context_conditioning_frames.device)
-                    print("shape of control_model_input", current_context_conditioning_frames.shape)
-                    print("device of cond_scale", cond_scale.device)
-                    # print error
-                    except Exception as e:
-                        print("error", e)
-
+                    control_model_input = control_model_input.reshape(
+                        (-1, control_model_input.shape[2], control_model_input.shape[3], control_model_input.shape[4])
+                    )
+
                     down_block_res_samples, mid_block_res_sample = self.controlnet(
                         control_model_input,
                         t,
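The transpose/reshape pair folds the frame axis into the batch so the 2D ControlNet can consume a 5D video latent: (batch, channels, frames, height, width) becomes (batch * frames, channels, height, width). A runnable sketch with hypothetical sizes (batch 2, 4 latent channels, 8 frames, 64x64 spatial):

    import torch

    control_model_input = torch.randn(2, 4, 8, 64, 64)  # (B, C, F, H, W)

    # swap the channel and frame axes: (B, C, F, H, W) -> (B, F, C, H, W)
    control_model_input = torch.transpose(control_model_input, 1, 2)
    # fold frames into the batch: (B, F, C, H, W) -> (B * F, C, H, W)
    control_model_input = control_model_input.reshape(
        (-1, control_model_input.shape[2], control_model_input.shape[3], control_model_input.shape[4])
    )
    print(control_model_input.shape)  # torch.Size([16, 4, 64, 64])

reshape rather than view is the right call here: the transposed tensor is no longer contiguous, and reshape copies when it has to.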
@@ -1339,7 +1311,7 @@ class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdapterMixin
                     ).sample

                 else:
-                    # predict the noise residual
+                    # predict the noise residual without controlnet
                     noise_pred = self.unet(
                         latent_model_input,
                         t,
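For orientation, the else branch is the no-ControlNet path: the motion UNet predicts the noise residual from the latents and prompt embeddings alone, while the ControlNet path additionally passes in the down/mid block residuals. A hedged sketch of the two calls, assuming the residual keyword names used by diffusers UNets (these kwargs are not shown in this diff):

    # ControlNet path: residuals from self.controlnet steer the UNet blocks
    noise_pred = self.unet(
        latent_model_input,
        t,
        encoder_hidden_states=prompt_embeds,
        down_block_additional_residuals=down_block_res_samples,
        mid_block_additional_residual=mid_block_res_sample,
    ).sample

    # plain path: predict the noise residual without controlnet
    noise_pred = self.unet(
        latent_model_input,
        t,
        encoder_hidden_states=prompt_embeds,
    ).sample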