linoyts HF Staff commited on
Commit
a66b8ed
·
verified ·
1 Parent(s): 7421e8c

Update pipeline.py

Browse files
Files changed (1) hide show
  1. pipeline.py +7 -7
pipeline.py CHANGED
@@ -240,7 +240,7 @@ class Flex2Pipeline(FluxControlPipeline):
240
  if control_image is None:
241
  control_latents = torch.zeros(
242
  batch_size * num_images_per_prompt,
243
- 3,
244
  latent_height,
245
  latent_width,
246
  device=device,
@@ -261,12 +261,11 @@ class Flex2Pipeline(FluxControlPipeline):
261
 
262
  # apply control strength
263
  control_latents = control_latents * control_strength
264
- print("control_latents", control_latents.shape)
265
 
266
  if inpaint_image is None and inpaint_mask is None:
267
  inpaint_latents = torch.zeros(
268
  batch_size * num_images_per_prompt,
269
- 3,
270
  latent_height,
271
  latent_width,
272
  device=device,
@@ -282,7 +281,7 @@ class Flex2Pipeline(FluxControlPipeline):
282
  )
283
  else:
284
  print("inpaint_image.shape",inpaint_image.size)
285
- print("inpaint_mask.shape",inpaint_mask.size)
286
  inpaint_image = self.prepare_image(
287
  image=inpaint_image,
288
  width=width,
@@ -294,7 +293,6 @@ class Flex2Pipeline(FluxControlPipeline):
294
  )
295
  inpaint_image = self.vae.encode(inpaint_image).latent_dist.sample(generator=generator)
296
  inpaint_latents = (inpaint_image - self.vae.config.shift_factor) * self.vae.config.scaling_factor
297
- print("inpaint_latents", inpaint_latents.shape)
298
  height_inpaint_image, width_inpaint_image = inpaint_image.shape[2:]
299
 
300
  inpaint_mask = self.prepare_image(
@@ -310,7 +308,7 @@ class Flex2Pipeline(FluxControlPipeline):
310
  inpaint_mask = inpaint_mask[:, 0:1, :, :] * 0.5 + 0.5
311
  # resize to match height_inpaint_image and width_inpaint_image
312
  inpaint_latents_mask = F.interpolate(inpaint_mask, size=(height_inpaint_image, width_inpaint_image), mode="bilinear", align_corners=False)
313
- print("inpaint_latents_mask", inpaint_latents_mask.shape)
314
  # apply inverted mask to inpaint latents
315
  inpaint_latents = inpaint_latents * (1 - inpaint_latents_mask)
316
 
@@ -443,4 +441,6 @@ class Flex2Pipeline(FluxControlPipeline):
443
  if not return_dict:
444
  return (image,)
445
 
446
- return FluxPipelineOutput(images=image)
 
 
 
240
  if control_image is None:
241
  control_latents = torch.zeros(
242
  batch_size * num_images_per_prompt,
243
+ 16,
244
  latent_height,
245
  latent_width,
246
  device=device,
 
261
 
262
  # apply control strength
263
  control_latents = control_latents * control_strength
 
264
 
265
  if inpaint_image is None and inpaint_mask is None:
266
  inpaint_latents = torch.zeros(
267
  batch_size * num_images_per_prompt,
268
+ 16,
269
  latent_height,
270
  latent_width,
271
  device=device,
 
281
  )
282
  else:
283
  print("inpaint_image.shape",inpaint_image.size)
284
+ print("inpaint_mask.shape",inpaint_mask.shape)
285
  inpaint_image = self.prepare_image(
286
  image=inpaint_image,
287
  width=width,
 
293
  )
294
  inpaint_image = self.vae.encode(inpaint_image).latent_dist.sample(generator=generator)
295
  inpaint_latents = (inpaint_image - self.vae.config.shift_factor) * self.vae.config.scaling_factor
 
296
  height_inpaint_image, width_inpaint_image = inpaint_image.shape[2:]
297
 
298
  inpaint_mask = self.prepare_image(
 
308
  inpaint_mask = inpaint_mask[:, 0:1, :, :] * 0.5 + 0.5
309
  # resize to match height_inpaint_image and width_inpaint_image
310
  inpaint_latents_mask = F.interpolate(inpaint_mask, size=(height_inpaint_image, width_inpaint_image), mode="bilinear", align_corners=False)
311
+
312
  # apply inverted mask to inpaint latents
313
  inpaint_latents = inpaint_latents * (1 - inpaint_latents_mask)
314
 
 
441
  if not return_dict:
442
  return (image,)
443
 
444
+ return FluxPipelineOutput(images=image)
445
+
446
+