zhiweili committed
Commit 7742553 · Parent(s): 82da816

add segment mask

app.py CHANGED
@@ -1,7 +1,22 @@
 import gradio as gr
-from PIL import Image
+import numpy as np
+import mediapipe as mp
 import torch
+
+from PIL import Image
 from diffusers import AutoPipelineForImage2Image, DPMSolverMultistepScheduler
+from mediapipe.tasks import python
+from mediapipe.tasks.python import vision
+from scipy.ndimage import binary_dilation
+
+BG_COLOR = np.array([0, 0, 0], dtype=np.uint8)  # black
+MASK_COLOR = np.array([255, 255, 255], dtype=np.uint8)  # white
+
+segment_model = "checkpoints/selfie_multiclass_256x256.tflite"
+base_options = python.BaseOptions(model_asset_path=segment_model)
+options = vision.ImageSegmenterOptions(base_options=base_options, output_category_mask=True)
+segmenter = vision.ImageSegmenter.create_from_options(options)
+MASK_CATEGORY = segmenter.labels
 
 base_model = "SG161222/RealVisXL_V4.0"
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -13,29 +28,64 @@ pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.config)
 pipeline.to(device)
 generator = torch.Generator(device).manual_seed(0)
 
-def image_to_image(input_image, prompt, guidance_scale, num_inference_steps):
-    # resize image to 512x512
-    input_image = input_image.resize((512, 512))
+def image_to_image(input_image, mask_image, prompt, negative_prompt, category, guidance_scale, num_inference_steps):
     # Generate the output image
     output_image = pipeline(
         generator=generator,
-        prompt=prompt, image=input_image,
-        guidance_scale=guidance_scale, num_inference_steps=num_inference_steps
+        prompt=prompt,
+        negative_prompt=negative_prompt,
+        image=input_image,
+        mask_image=mask_image,
+        guidance_scale=guidance_scale,
+        num_inference_steps=num_inference_steps,
     ).images[0]
 
     return output_image
 
+def segment_image(input_image, prompt, negative_prompt, category, guidance_scale, num_inference_steps):
+    image = mp.Image(image_format=mp.ImageFormat.SRGB, data=np.asarray(input_image))
+    segmentation_result = segmenter.segment(image)
+    category_mask = segmentation_result.category_mask
+    category_mask_np = category_mask.numpy_view()
+    target_mask = category_mask_np == MASK_CATEGORY.index(category)
+
+    # Generate solid color images for showing the output segmentation mask.
+    image_data = image.numpy_view()
+    fg_image = np.zeros(image_data.shape, dtype=np.uint8)
+    fg_image[:] = MASK_COLOR
+    bg_image = np.zeros(image_data.shape, dtype=np.uint8)
+    bg_image[:] = BG_COLOR
+
+    dilated_mask = binary_dilation(target_mask, iterations=4)
+    condition = np.stack((dilated_mask,) * 3, axis=-1) > 0.2
+
+    output_image = np.where(condition, fg_image, bg_image)
+    output_image = Image.fromarray(output_image)
+    return output_image
+
 with gr.Blocks() as grApp:
-    input_image = gr.Image(label="Input Image", type="pil")
-    prompt = gr.Textbox(lines=3, label="Prompt")
-    guidance_scale = gr.Slider(minimum=0, maximum=1, value=0.75, label="Guidance Scale")
-    num_inference_steps = gr.Slider(minimum=10, maximum=100, value=25, label="Number of Inference Steps")
-    output_image = gr.Image(label="Output Image", type="pil")
-    generate_btn = gr.Button("Generate Image")
+    with gr.Row():
+        with gr.Column():
+            prompt = gr.Textbox(lines=1, label="Prompt")
+            negative_prompt = gr.Textbox(lines=2, label="Negative Prompt")
+            category = gr.Dropdown(label='Mask Category', choices=MASK_CATEGORY, value=MASK_CATEGORY[1])
+            guidance_scale = gr.Slider(minimum=0, maximum=1, value=0.75, label="Guidance Scale")
+            num_inference_steps = gr.Slider(minimum=10, maximum=100, value=25, label="Number of Inference Steps")
+            input_image = gr.Image(label="Input Image", type="pil")
+            generate_btn = gr.Button("Generate Image")
+        with gr.Column():
+            mask_image = gr.Image(label="Mask Image", type="pil")
+        with gr.Column():
+            output_image = gr.Image(label="Output Image", type="pil")
+
     generate_btn.click(
+        fn=segment_image,
+        inputs=[input_image, prompt, category, guidance_scale, num_inference_steps],
+        outputs=[mask_image],
+    ).then(
        fn=image_to_image,
-        inputs=[input_image, prompt, guidance_scale, num_inference_steps],
-        outputs=output_image,
+        inputs=[input_image, mask_image, prompt, negative_prompt, category, guidance_scale, num_inference_steps],
+        outputs=[output_image],
     )
 
 grApp.launch()
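Note on the updated image_to_image: it forwards mask_image into a call on AutoPipelineForImage2Image, and whether the underlying SDXL image-to-image pipeline actually consumes a mask depends on the installed diffusers version. A minimal sketch of the mask-aware edit using diffusers' inpainting auto-pipeline instead follows; it is an assumption about the intended behaviour, not part of this commit, and masked_edit is a hypothetical helper name.

# Hedged sketch: masked editing via an inpainting pipeline instead of plain img2img.
# Assumes AutoPipelineForInpainting can load the RealVisXL checkpoint; not part of this commit.
import torch
from diffusers import AutoPipelineForInpainting

inpaint_device = "cuda" if torch.cuda.is_available() else "cpu"
inpaint_pipe = AutoPipelineForInpainting.from_pretrained("SG161222/RealVisXL_V4.0").to(inpaint_device)

def masked_edit(input_image, mask_image, prompt, negative_prompt, guidance_scale, num_inference_steps):
    # White regions of mask_image are repainted; black regions are kept from input_image.
    return inpaint_pipe(
        prompt=prompt,
        negative_prompt=negative_prompt,
        image=input_image,
        mask_image=mask_image,
        guidance_scale=guidance_scale,
        num_inference_steps=num_inference_steps,
    ).images[0]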
checkpoints/selfie_multiclass_256x256.tflite ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c6748b1253a99067ef71f7e26ca71096cd449baefa8f101900ea23016507e0e0
+size 16371837
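The checkpoint itself is committed only as a Git LFS pointer. For reference, a small sketch for fetching the same model locally follows; the download URL is the one MediaPipe publishes for its multiclass selfie segmenter and is an assumption here, not something recorded in this commit.

# Hedged sketch: fetch the selfie multiclass segmentation model into checkpoints/.
# The URL is assumed from MediaPipe's public model listing; verify it before relying on it.
import os
import urllib.request

MODEL_URL = (
    "https://storage.googleapis.com/mediapipe-models/image_segmenter/"
    "selfie_multiclass_256x256/float32/latest/selfie_multiclass_256x256.tflite"
)
DEST = "checkpoints/selfie_multiclass_256x256.tflite"

os.makedirs(os.path.dirname(DEST), exist_ok=True)
if not os.path.exists(DEST):
    urllib.request.urlretrieve(MODEL_URL, DEST)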
requirements.txt CHANGED
@@ -2,4 +2,5 @@ gradio
 torch
 diffusers
 transformers
-accelerate
+accelerate
+mediapipe
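One hedged observation: the updated app.py imports binary_dilation from scipy.ndimage, but scipy is not added to requirements.txt here, so the Space relies on it being pulled in transitively. A quick import check, mirroring the dependencies the new code actually touches:

# Hedged sketch: confirm every package the new app.py imports is installed.
# "scipy" is included because requirements.txt does not list it explicitly.
import importlib.util

for pkg in ("gradio", "torch", "diffusers", "transformers",
            "accelerate", "mediapipe", "scipy", "numpy", "PIL"):
    if importlib.util.find_spec(pkg) is None:
        print(f"missing dependency: {pkg}")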