Bethie committed
Commit 28da6d8 · verified · 1 Parent(s): 1d0c8e6

Modified ONNX pipeline

code_inference/pipeline_sdxl_cnext_ipadapter.py CHANGED
@@ -310,7 +310,7 @@ class StableDiffusionXLControlNeXtPipeline():
     controlnext: ort.InferenceSession,
     image_proj: ort.InferenceSession,
     scheduler: DDPMScheduler,
-    image_encoder: CLIPVisionModelWithProjection = None,
+    image_encoder: ort.InferenceSession,
     feature_extractor: CLIPImageProcessor = None,
     add_watermarker: Optional[bool] = None,
     device=None,
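
The first hunk swaps the image encoder from a transformers CLIPVisionModelWithProjection to an ONNX Runtime InferenceSession, matching the other ONNX components of the pipeline. Below is a minimal sketch of how such a session could be built and queried; the file name "image_encoder.onnx" and the input/output names are assumptions for illustration, not taken from this repo.

```python
# Minimal sketch, assuming the image encoder was exported to ONNX with a
# "pixel_values" input; names and paths below are illustrative.
import numpy as np
import onnxruntime as ort
from PIL import Image
from transformers import CLIPImageProcessor

feature_extractor = CLIPImageProcessor()
image_encoder = ort.InferenceSession("image_encoder.onnx",        # assumed file name
                                     providers=["CPUExecutionProvider"])

image = Image.open("ip_adapter_reference.png").convert("RGB")     # any reference image
pixel_values = feature_extractor(images=image, return_tensors="np").pixel_values

# run() returns a list of outputs; the first output is taken here as the image
# embedding. Check image_encoder.get_inputs()/get_outputs() for the real names.
image_embeds = image_encoder.run(None, {"pixel_values": pixel_values.astype(np.float32)})[0]
```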
@@ -1122,8 +1122,8 @@ class StableDiffusionXLControlNeXtPipeline():
     image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
     uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_images_per_prompt, 1)
     uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
-    prompt_embeds = torch.cat([torch.zeros_like(prompt_embeds), image_prompt_embeds], dim=1)
-    negative_prompt_embeds = torch.cat([torch.zeros_like(negative_prompt_embeds), uncond_image_prompt_embeds], dim=1)
+    prompt_embeds = torch.cat([prompt_embeds, image_prompt_embeds], dim=1)
+    negative_prompt_embeds = torch.cat([negative_prompt_embeds, uncond_image_prompt_embeds], dim=1)

     if self.do_classifier_free_guidance:
         prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
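
The second hunk changes how the IP-Adapter image tokens are merged with the text conditioning: instead of concatenating them onto zeroed text embeddings, they are now appended to the real text tokens along the sequence dimension. A small shape sketch makes the difference concrete; the token counts and the 2048-dim embedding are illustrative assumptions, not values read from this repo.

```python
# Illustrative shapes only; 77 text tokens, 16 image tokens and dim=2048 are assumptions.
import torch

batch, text_len, img_len, dim = 2, 77, 16, 2048
prompt_embeds = torch.randn(batch, text_len, dim)        # stands in for the text encoder output
image_prompt_embeds = torch.randn(batch, img_len, dim)   # stands in for the IP-Adapter projection

# Before: the text context was zeroed, so only the image tokens carried conditioning.
old = torch.cat([torch.zeros_like(prompt_embeds), image_prompt_embeds], dim=1)

# After: image tokens extend the real text tokens, so both prompts condition the UNet.
new = torch.cat([prompt_embeds, image_prompt_embeds], dim=1)

assert old.shape == new.shape == (batch, text_len + img_len, dim)
```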
@@ -1182,7 +1182,7 @@ class StableDiffusionXLControlNeXtPipeline():
     controls = self.controlnet.run(None, {'controlnext_image': controlnet_image.cpu().numpy(),
                                           'timestep': t.unsqueeze(0).cpu().numpy().astype(np.float32),})

-    scale = torch.tensor([1.00])
+    scale = torch.tensor([control_scale])

     noise_pred = self.unet.run(None, {'sample': latent_model_input.cpu().numpy().astype(np.float32),
                                       'timestep': t.unsqueeze(0).cpu().numpy().astype(np.float32),
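
The third hunk stops hardcoding the ControlNeXt strength at 1.00 and builds the scale tensor from a `control_scale` variable, presumably supplied by the caller; the surrounding parameter plumbing is not shown in this diff. A hedged sketch of the pattern:

```python
# Sketch only: `control_scale` is assumed to arrive as a caller-supplied argument.
import torch

control_scale = 0.8                               # caller-chosen strength; 1.0 reproduces the old behaviour
scale = torch.tensor([control_scale])             # replaces the fixed torch.tensor([1.00])
scale_np = scale.cpu().numpy().astype("float32")  # ONNX Runtime inputs expect numpy arrays
```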
 