zhiweili committed · Commit 7742553 · 1 Parent(s): 82da816

add segment mask
Browse files:
- app.py +64 -14
- checkpoints/selfie_multiclass_256x256.tflite +3 -0
- requirements.txt +2 -1
app.py CHANGED
@@ -1,7 +1,22 @@
 import gradio as gr
-
+import numpy as np
+import mediapipe as mp
 import torch
+
+from PIL import Image
 from diffusers import AutoPipelineForImage2Image, DPMSolverMultistepScheduler
+from mediapipe.tasks import python
+from mediapipe.tasks.python import vision
+from scipy.ndimage import binary_dilation
+
+BG_COLOR = np.array([0, 0, 0], dtype=np.uint8)  # black
+MASK_COLOR = np.array([255, 255, 255], dtype=np.uint8)  # white
+
+segment_model = "checkpoints/selfie_multiclass_256x256.tflite"
+base_options = python.BaseOptions(model_asset_path=segment_model)
+options = vision.ImageSegmenterOptions(base_options=base_options, output_category_mask=True)
+segmenter = vision.ImageSegmenter.create_from_options(options)
+MASK_CATEGORY = segmenter.labels
 
 base_model = "SG161222/RealVisXL_V4.0"
 device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -13,29 +28,64 @@ pipeline.scheduler = DPMSolverMultistepScheduler.from_config(pipeline.scheduler.
 pipeline.to(device)
 generator = torch.Generator(device).manual_seed(0)
 
-def image_to_image(input_image, prompt, guidance_scale, num_inference_steps):
-    # resize image to 512x512
-    input_image = input_image.resize((512, 512))
+def image_to_image(input_image, mask_image, prompt, negative_prompt, category, guidance_scale, num_inference_steps):
     # Generate the output image
     output_image = pipeline(
         generator=generator,
-        prompt=prompt,
-
+        prompt=prompt,
+        negative_prompt=negative_prompt,
+        image=input_image,
+        mask_image=mask_image,
+        guidance_scale=guidance_scale,
+        num_inference_steps=num_inference_steps,
     ).images[0]
 
     return output_image
 
+def segment_image(input_image, prompt, category, guidance_scale, num_inference_steps):
+    image = mp.Image(image_format=mp.ImageFormat.SRGB, data=np.asarray(input_image))
+    segmentation_result = segmenter.segment(image)
+    category_mask = segmentation_result.category_mask
+    category_mask_np = category_mask.numpy_view()
+    target_mask = category_mask_np == MASK_CATEGORY.index(category)
+
+    # Generate solid color images for showing the output segmentation mask.
+    image_data = image.numpy_view()
+    fg_image = np.zeros(image_data.shape, dtype=np.uint8)
+    fg_image[:] = MASK_COLOR
+    bg_image = np.zeros(image_data.shape, dtype=np.uint8)
+    bg_image[:] = BG_COLOR
+
+    dilated_mask = binary_dilation(target_mask, iterations=4)
+    condition = np.stack((dilated_mask,) * 3, axis=-1) > 0.2
+
+    output_image = np.where(condition, fg_image, bg_image)
+    output_image = Image.fromarray(output_image)
+    return output_image
+
 with gr.Blocks() as grApp:
-
-
-
-
-
-
+    with gr.Row():
+        with gr.Column():
+            prompt = gr.Textbox(lines=1, label="Prompt")
+            negative_prompt = gr.Textbox(lines=2, label="Negative Prompt")
+            category = gr.Dropdown(label='Mask Category', choices=MASK_CATEGORY, value=MASK_CATEGORY[1])
+            guidance_scale = gr.Slider(minimum=0, maximum=1, value=0.75, label="Guidance Scale")
+            num_inference_steps = gr.Slider(minimum=10, maximum=100, value=25, label="Number of Inference Steps")
+            input_image = gr.Image(label="Input Image", type="pil")
+            generate_btn = gr.Button("Generate Image")
+        with gr.Column():
+            mask_image = gr.Image(label="Mask Image", type="pil")
+        with gr.Column():
+            output_image = gr.Image(label="Output Image", type="pil")
+
     generate_btn.click(
+        fn=segment_image,
+        inputs=[input_image, prompt, category, guidance_scale, num_inference_steps],
+        outputs=[mask_image],
+    ).then(
         fn=image_to_image,
-        inputs=[input_image, prompt, guidance_scale, num_inference_steps],
-        outputs=output_image,
+        inputs=[input_image, mask_image, prompt, negative_prompt, category, guidance_scale, num_inference_steps],
+        outputs=[output_image],
     )
 
 grApp.launch()
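For reference, the new masking step can be exercised on its own. The sketch below mirrors segment_image outside of Gradio; the input file photo.jpg and the hard-coded "hair" category are illustrative stand-ins, and it assumes the same checkpoint path as the Space. The multiclass selfie model labels pixels as background, hair, body-skin, face-skin, clothes, or others, which is why the UI's Dropdown defaults to MASK_CATEGORY[1] ("hair").

# Standalone sketch of the segment-then-dilate masking step (illustrative).
import numpy as np
import mediapipe as mp
from mediapipe.tasks import python
from mediapipe.tasks.python import vision
from PIL import Image
from scipy.ndimage import binary_dilation

options = vision.ImageSegmenterOptions(
    base_options=python.BaseOptions(
        model_asset_path="checkpoints/selfie_multiclass_256x256.tflite"),
    output_category_mask=True,
)
segmenter = vision.ImageSegmenter.create_from_options(options)

# photo.jpg is a hypothetical input; any RGB image works.
pil_image = Image.open("photo.jpg").convert("RGB")
mp_image = mp.Image(image_format=mp.ImageFormat.SRGB, data=np.asarray(pil_image))

# The category mask holds one label index per pixel; segmenter.labels names
# the model's classes (background, hair, body-skin, face-skin, clothes, others).
result = segmenter.segment(mp_image)
target = result.category_mask.numpy_view() == segmenter.labels.index("hair")

# Dilate a few pixels so the downstream diffusion edit covers the mask boundary.
target = binary_dilation(target, iterations=4)
Image.fromarray(target.astype(np.uint8) * 255).save("mask.png")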
checkpoints/selfie_multiclass_256x256.tflite ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:c6748b1253a99067ef71f7e26ca71096cd449baefa8f101900ea23016507e0e0
+size 16371837
requirements.txt CHANGED
@@ -2,4 +2,5 @@ gradio
 torch
 diffusers
 transformers
-accelerate
+accelerate
+mediapipe
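To try the Space locally, it should be enough to run pip install -r requirements.txt, pull the tflite checkpoint with git lfs pull, and launch with python app.py. Note that app.py also imports scipy (for binary_dilation), which is not listed here, so it has to arrive as a transitive dependency or be installed separately.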