Spaces: Running on Zero

Commit: Update flow

app.py CHANGED
```diff
@@ -8,10 +8,9 @@ from diffusers.utils import load_image
 from torchvision.transforms.functional import to_tensor, gaussian_blur
 from matplotlib import pyplot as plt
 import gradio as gr
-import spaces
 from gradio_imageslider import ImageSlider
 from torchvision.transforms.functional import to_pil_image, to_tensor
-from PIL import ImageFilter
+from PIL import ImageFilter, Image
 import traceback
 
 
```
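This hunk drops `import spaces` (its `@spaces.GPU` decorator is removed further down) and widens the PIL import to include `Image` for the new mask conversion. For context, a minimal sketch of the ZeroGPU pattern being removed, assuming the Hugging Face `spaces` package that ZeroGPU Spaces provide:

```python
# Sketch of the pattern this commit removes (ZeroGPU-only).
# On ZeroGPU, CUDA is attached per call: the decorator requests
# a GPU slice for the duration of each invocation.
import spaces

@spaces.GPU
def remove(gradio_image, rm_guidance_scale=9, num_inference_steps=50,
           seed=42, strength=0.8):
    ...  # body unchanged; runs with CUDA available
```

Without the decorator, the app assumes a persistent GPU, which matches the hard-coded `torch.Generator('cuda')` introduced below.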
```diff
@@ -22,19 +21,34 @@ def preprocess_image(input_image, device):
     image = image.expand(-1, 3, -1, -1)
     image = F.interpolate(image, (1024, 1024))
     image = image.to(dtype).to(device)
-
+
     return image
 
 
 def preprocess_mask(input_mask, device):
-    mask = to_tensor(input_mask.convert('L'))
+    # Split the channels
+    r, g, b, alpha = input_mask.split()
+
+    # Create a new image where:
+    # - Black areas (where RGB = 0) become white (255).
+    # - Transparent areas (where alpha = 0) become black (0).
+    new_mask = Image.new("L", input_mask.size)
+
+    for x in range(input_mask.width):
+        for y in range(input_mask.height):
+            if alpha.getpixel((x, y)) == 0:  # Transparent pixel
+                new_mask.putpixel((x, y), 0)  # Set to black
+            else:  # Non-transparent pixel (originally black in the mask)
+                new_mask.putpixel((x, y), 255)  # Set to white
+
+    mask = to_tensor(new_mask.convert('L'))
     mask = mask.unsqueeze_(0).float()  # 0 or 1
     mask = F.interpolate(mask, (1024, 1024))
     mask = gaussian_blur(mask, kernel_size=(77, 77))
     mask[mask < 0.1] = 0
     mask[mask >= 0.1] = 1
     mask = mask.to(dtype).to(device)
-
+
     return mask
 
 
```
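The new `preprocess_mask` walks every pixel in Python (`getpixel`/`putpixel`), which is O(W·H) interpreter-level work: over a million calls for a 1200×1200 canvas. A sketch of an equivalent vectorized conversion, assuming the same contract (RGBA layer in, binary "L" mask out); `alpha_to_mask` is a hypothetical helper name:

```python
from PIL import Image

def alpha_to_mask(input_mask: Image.Image) -> Image.Image:
    # Same rule as the loop above: alpha == 0 -> black (0),
    # any painted pixel (alpha > 0) -> white (255).
    alpha = input_mask.getchannel("A")
    return alpha.point(lambda a: 255 if a > 0 else 0)
```

The result can feed straight into `to_tensor`, exactly as `new_mask` does above.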
```diff
@@ -42,7 +56,7 @@ def make_redder(img, mask, increase_factor=0.4):
     img_redder = img.clone()
     mask_expanded = mask.expand_as(img)
     img_redder[0][mask_expanded[0] == 1] = torch.clamp(img_redder[0][mask_expanded[0] == 1] + increase_factor, 0, 1)
-
+
     return img_redder
 
 
```
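`make_redder` is a visualization helper: it lifts pixel values inside the mask by `increase_factor`. Note that because the mask is expanded across all three channels before indexing, the boost applies to every channel, so in its current form it brightens the region rather than strictly reddening it. A hypothetical debugging call, assuming a `(1, 3, H, W)` image tensor in `[0, 1]` and a `(1, 1, H, W)` binary mask:

```python
# Hypothetical usage: tint the masked region on the source image
# so the area to be edited is easy to eyeball.
highlighted = make_redder(source_image.cpu().float(), mask.cpu().float())
to_pil_image(highlighted.squeeze(0)).save("mask_overlay.png")
```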
```diff
@@ -67,29 +81,28 @@ pipeline = DiffusionPipeline.from_pretrained(
 
 if is_attention_slicing_enabled:
     pipeline.enable_attention_slicing()
-
+
 if is_cpu_offload_enabled:
     pipeline.enable_model_cpu_offload()
 
 
-@spaces.GPU
 def remove(gradio_image, rm_guidance_scale=9, num_inference_steps=50, seed=42, strength=0.8):
     try:
-        generator = torch.Generator(
+        generator = torch.Generator('cuda').manual_seed(seed)
         prompt = ""  # Set prompt to null
 
         source_image_pure = gradio_image["background"]
         mask_image_pure = gradio_image["layers"][0]
         source_image = preprocess_image(source_image_pure.convert('RGB'), device)
         mask = preprocess_mask(mask_image_pure, device)
-
+
         START_STEP = 0  # AAS start step
         END_STEP = int(strength * num_inference_steps)  # AAS end step
-        LAYER = 34  # 0~23down,24~33mid,34~69up /AAS start layer
+        LAYER = 34  # 0~23down,24~33mid,34~69up /AAS start layer
         END_LAYER = 70  # AAS end layer
         ss_steps = 9  # similarity suppression steps
         ss_scale = 0.3  # similarity suppression scale
-
+
         image = pipeline(
             prompt=prompt,
             image=source_image,
```
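The rewritten handler seeds a CUDA generator directly (the removed line is truncated in the diff view, so only the replacement is fully visible). Hard-coding `'cuda'` assumes a persistent GPU, consistent with dropping `@spaces.GPU`; on a CPU host, `torch.Generator('cuda')` raises. A device-agnostic sketch, assuming the module-level `device` already used by `preprocess_image`:

```python
import torch

# Fall back to a CPU generator when CUDA is absent; diffusers
# pipelines accept a CPU generator regardless of where the model lives.
gen_device = device if torch.cuda.is_available() else "cpu"
generator = torch.Generator(gen_device).manual_seed(seed)
```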
```diff
@@ -102,26 +115,25 @@ def remove(gradio_image, rm_guidance_scale=9, num_inference_steps=50, seed=42, s
             ss_steps=ss_steps,  # similarity suppression steps
             ss_scale=ss_scale,  # similarity suppression scale
             AAS_start_step=START_STEP,  # AAS start step
-            AAS_start_layer=LAYER,  # AAS start layer
+            AAS_start_layer=LAYER,  # AAS start layer
             AAS_end_layer=END_LAYER,  # AAS end layer
             num_inference_steps=num_inference_steps,  # number of inference steps # AAS_end_step = int(strength*num_inference_steps)
             generator=generator,
-            guidance_scale=1,
-            output_type='pt'
+            guidance_scale=1
         ).images[0]
         print('Inference: DONE.')
-
+
         pil_mask = to_pil_image(mask.squeeze(0))
         pil_mask_blurred = pil_mask.filter(ImageFilter.GaussianBlur(radius=15))
         mask_blurred = to_tensor(pil_mask_blurred).unsqueeze_(0).to(mask.device)
         mask_f = 1 - (1 - mask) * (1 - mask_blurred)
-
-        image_1 = image.unsqueeze(0)
-
-        return source_image_pure, pil_mask,
+
+        # image_1 = image.unsqueeze(0)
+
+        return source_image_pure, pil_mask, image
     except:
         print(traceback.format_exc())
-
+
 
 title = """<h1 align="center">Object Remove</h1>"""
 with gr.Blocks() as demo:
```
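Two things stand out in this hunk. First, `mask_f`, the feathered union of the hard mask and its Gaussian-blurred copy, is computed but never used in the return value. Second, with `output_type='pt'` removed, `pipeline(...).images[0]` is a PIL image rather than a tensor, which is presumably why the `image_1` line is now commented out. If the feathered paste-back were reinstated, a purely hypothetical tensor-space sketch (it assumes `output_type='pt'` is restored so `image` is a `(3, 1024, 1024)` tensor in `[0, 1]`):

```python
# Hypothetical composite: keep generated pixels inside the feathered
# mask and original pixels outside it. mask_f broadcasts over channels.
image_1 = image.unsqueeze(0).to(mask_f.device)    # (1, 3, 1024, 1024)
blended = image_1 * mask_f + source_image * (1 - mask_f)
result_pil = to_pil_image(blended.squeeze(0).clamp(0, 1).float().cpu())
```

Also worth noting: the bare `except:` swallows everything, including `KeyboardInterrupt`; `except Exception:` is the safer idiom.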
```diff
@@ -157,7 +169,7 @@ with gr.Blocks() as demo:
             step=0.1,
             label="Strength"
         )
-
+
         input_image = gr.ImageMask(
             type="pil", label="Input Image", crop_size=(1200, 1200), layers=False
         )
```
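For reference, `gr.ImageMask` passes the handler a dict-like editor value; the `"background"` and `"layers"` keys read in `remove()` come from this structure. A minimal sketch of unpacking it (the helper name is hypothetical; `"composite"` is unused by this app):

```python
def unpack_editor_value(gradio_image):
    # gr.ImageMask(type="pil") delivers a dict with three entries:
    #   "background": the uploaded PIL image,
    #   "layers":     RGBA PIL layers the user painted,
    #   "composite":  background with all layers flattened on top.
    background = gradio_image["background"]
    mask_layer = gradio_image["layers"][0]  # first painted layer = the mask
    return background, mask_layer
```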
```diff
@@ -167,11 +179,11 @@ with gr.Blocks() as demo:
         run_button = gr.Button("Generate")
 
         result = gr.Gallery(label="Generated images", show_label=False, elem_id="gallery", columns=[3], rows=[1], object_fit="contain", height="auto")
-
+
         run_button.click(
             fn=remove,
             inputs=[input_image, guidance_scale, num_steps, seed, strength],
             outputs=result,
         )
-
-demo.queue(max_size=12).launch(share=
+
+demo.queue(max_size=12).launch(share=True)
```
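The three images `remove()` now returns (source, mask, result) line up with the gallery's `columns=[3]` layout. One last note: on Hugging Face Spaces, `launch(share=True)` is ignored with a warning, since the Space is already served publicly; the flag only matters when the script runs locally. A hypothetical guard:

```python
import os

# Only request a Gradio share link outside Spaces, where it has an effect.
# SPACE_ID is set in the Hugging Face Spaces runtime environment.
is_spaces = os.environ.get("SPACE_ID") is not None
demo.queue(max_size=12).launch(share=not is_spaces)
```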