Surn commited on
Commit
091e9bc
·
1 Parent(s): 2ce0081

Address 2 bugs

Browse files

1. missing mask sketch
2. excessive Img2Img load time

Files changed (2) hide show
  1. app.py +15 -9
  2. utils/image_utils.py +2 -0
app.py CHANGED
@@ -423,15 +423,21 @@ def generate_image_lowmem(
423
  mask_parameters = {}
424
  # Load the mask image if provided
425
  if (pipeline_name == "FluxFillPipeline"):
426
- mask_image = open_image(mask_image).convert("RGBA")
427
- mask_condition_type = constants.condition_type[5]
428
- guidance_scale = 30
429
- num_inference_steps=50
430
- max_sequence_length=512
431
- print(f"\nAdded mask image.\n {mask_image.size}")
432
- mask_parameters ={
433
- "mask_image": mask_image,
434
- }
 
 
 
 
 
 
435
 
436
  # Set the random seed for reproducibility
437
  generator = torch.Generator(device=device).manual_seed(seed)
 
423
  mask_parameters = {}
424
  # Load the mask image if provided
425
  if (pipeline_name == "FluxFillPipeline"):
426
+ try:
427
+ mask_image = open_image(mask_image).convert("RGBA")
428
+ mask_condition_type = constants.condition_type[5]
429
+ guidance_scale = 30
430
+ num_inference_steps=50
431
+ max_sequence_length=512
432
+ print(f"\nAdded mask image.\n {mask_image.size}")
433
+ mask_parameters ={
434
+ "mask_image": mask_image,
435
+ }
436
+ except Exception as e:
437
+ print(f"Error loading mask image: {e}")
438
+ mask_image = None
439
+ gr.Warning("Please sketch a mask image to use the Fill model.")
440
+ raise Exception("Please sketch a mask image to use the Fill model.")
441
 
442
  # Set the random seed for reproducibility
443
  generator = torch.Generator(device=device).manual_seed(seed)
utils/image_utils.py CHANGED
@@ -61,6 +61,8 @@ def get_image_from_dict(image_path):
61
  image_path = image_path.get('composite')
62
  elif 'image' in image_path:
63
  image_path = image_path.get('image')
 
 
64
  else:
65
  print("\n Unknown image dictionary.\n")
66
  raise UserWarning("Unknown image dictionary.")
 
61
  image_path = image_path.get('composite')
62
  elif 'image' in image_path:
63
  image_path = image_path.get('image')
64
+ elif 'background' in image_path:
65
+ image_path = image_path.get('background')
66
  else:
67
  print("\n Unknown image dictionary.\n")
68
  raise UserWarning("Unknown image dictionary.")