smoothieAI committed
Commit 418ded7 · verified · 1 Parent(s): 5f14f2c

Update pipeline.py

Files changed (1):
  1. pipeline.py +13 -41
pipeline.py CHANGED
@@ -136,16 +136,13 @@ class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdap
             EulerAncestralDiscreteScheduler,
             DPMSolverMultistepScheduler,
         ],
-        controlnet: Union[ControlNetModel, List[ControlNetModel], Tuple[ControlNetModel], MultiControlNetModel] = None,
-        feature_extractor: CLIPImageProcessor = None,
-        image_encoder: CLIPVisionModelWithProjection = None,
+        controlnet: Union[ControlNetModel, MultiControlNetModel],
+        feature_extractor: Optional[CLIPImageProcessor] = None,
+        image_encoder: Optional[CLIPVisionModelWithProjection] = None,
     ):
         super().__init__()
         unet = UNetMotionModel.from_unet2d(unet, motion_adapter)
-
-        if isinstance(controlnet, (list, tuple)):
-            controlnet = MultiControlNetModel(controlnet)
-
+
         self.register_modules(
             vae=vae,
             text_encoder=text_encoder,
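
Note: with this signature the pipeline no longer wraps a bare list or tuple of ControlNets itself, so a caller with several nets now has to build the MultiControlNetModel up front. A minimal caller-side sketch, assuming stock diffusers classes and hypothetical checkpoint IDs (neither is taken from this repo):

    # Hypothetical caller-side sketch: the constructor no longer accepts a
    # bare list of ControlNets, so wrap them in MultiControlNetModel yourself.
    from diffusers import ControlNetModel
    from diffusers.pipelines.controlnet import MultiControlNetModel

    pose_net = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-openpose")  # assumed checkpoint
    depth_net = ControlNetModel.from_pretrained("lllyasviel/sd-controlnet-depth")    # assumed checkpoint
    controlnet = MultiControlNetModel([pose_net, depth_net])  # pass this to the pipeline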
@@ -1260,31 +1257,15 @@ class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdap
                 # select the relevant context from the latents
                 current_context_latents = latents[:, :, current_context_indexes, :, :]

-                if self.controlnet != None:
-                    # if we are using multiple controlnets, select the context window for each controlnet
-                    if isinstance(controlnet, MultiControlNetModel):
-                        print("length of conditioning_frames", len(conditioning_frames))
-                        current_context_conditioning_frames = [conditioning_frames[c][current_context_indexes, :, :, :] for c in range(len(controlnet.nets))]
-                        # move to device
-                        current_context_conditioning_frames = [c.to(device) for c in current_context_conditioning_frames]
-                        # print shape of current context conditioning frames [0]
-                        print("shape of current context conditioning frames", current_context_conditioning_frames[0].shape)
-                        # print device
-                        print("device of current context conditioning frames", current_context_conditioning_frames[0].device)
-                    else:
-                        # select the relevant context from the conditioning frames of shape (frame_number, channel, height, width)
-                        current_context_conditioning_frames = conditioning_frames[current_context_indexes, :, :, :]
-                        current_context_conditioning_frames = current_context_conditioning_frames.to(device)
-                else:
-                    current_context_conditioning_frames = None
-
-
                 # expand the latents if we are doing classifier free guidance
                 latent_model_input = torch.cat([current_context_latents] * 2) if do_classifier_free_guidance else current_context_latents
                 latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)


                 if self.controlnet != None:
+
+                    current_context_conditioning_frames = conditioning_frames[current_context_indexes, :, :, :]
+
                     if guess_mode and self.do_classifier_free_guidance:
                         # Infer ControlNet only for the conditional batch.
                         control_model_input = latents
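
Note: the conditioning-frame selection now lives inside the ControlNet branch and covers only the single-ControlNet case: current_context_indexes slices the sliding context window out of a (frame_number, channel, height, width) tensor. An illustrative sketch, with sizes assumed (only the shape convention comes from the code's own comment):

    # Illustrative only; concrete sizes are assumptions.
    import torch

    num_frames, C, H, W = 24, 3, 512, 512
    conditioning_frames = torch.randn(num_frames, C, H, W)
    current_context_indexes = list(range(16))  # e.g. a 16-frame context window
    window = conditioning_frames[current_context_indexes, :, :, :]
    print(window.shape)  # torch.Size([16, 3, 512, 512])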
@@ -1302,21 +1283,12 @@ class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdap
                     if isinstance(controlnet_cond_scale, list):
                         controlnet_cond_scale = controlnet_cond_scale[0]
                     cond_scale = controlnet_cond_scale * controlnet_keep[i]
-
+
                     control_model_input = torch.transpose(control_model_input, 1, 2)
-                    control_model_input = control_model_input.reshape((-1, control_model_input.shape[2], control_model_input.shape[3], control_model_input.shape[4]))
-                    # print device for all inputs
-                    try:
-                        print("cond_scale", cond_scale)
-                        print("device of control_model_input", control_model_input.device)
-                        print("device of controlnet_prompt_embeds", controlnet_prompt_embeds.device)
-                        print("device of current_context_conditioning_frames", current_context_conditioning_frames.device)
-                        print("shape of control_model_input", current_context_conditioning_frames.shape)
-                        print("device of cond_scale", cond_scale.device)
-                    # print error
-                    except Exception as e:
-                        print("error", e)
-
+                    control_model_input = control_model_input.reshape(
+                        (-1, control_model_input.shape[2], control_model_input.shape[3], control_model_input.shape[4])
+                    )
+
                     down_block_res_samples, mid_block_res_sample = self.controlnet(
                         control_model_input,
                         t,
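
Note: the reshape that replaces the debug prints folds the frame axis into the batch axis so the 2D ControlNet receives an ordinary image batch. A small sketch of the shape arithmetic, sizes assumed:

    # Assumed sizes; the pattern matches the transpose + reshape above.
    import torch

    B, C, F, h, w = 2, 4, 16, 64, 64                       # batch, channels, frames, height, width
    x = torch.randn(B, C, F, h, w)
    x = torch.transpose(x, 1, 2)                           # (B, F, C, h, w)
    x = x.reshape(-1, x.shape[2], x.shape[3], x.shape[4])  # (B*F, C, h, w)
    print(x.shape)  # torch.Size([32, 4, 64, 64])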
@@ -1339,7 +1311,7 @@ class AnimateDiffPipeline(DiffusionPipeline, TextualInversionLoaderMixin, IPAdap
                     ).sample

                 else:
-                    # predict the noise residual
+                    # predict the noise residual without controlnet
                     noise_pred = self.unet(
                         latent_model_input,
                         t,
 