mahan_ym
committed on
Commit
·
d8cda25
1
Parent(s):
d63f692
added remove background
Browse files- Makefile +6 -1
- src/app.py +27 -2
- src/assets/examples/test_6.jpg +3 -0
- src/assets/icons/hf-logo.svg +8 -0
- src/assets/icons/python-logo-only.svg +265 -0
- src/modal_app.py +34 -3
- src/tools.py +31 -2
Makefile
CHANGED
@@ -21,4 +21,9 @@ dev:
|
|
21 |
|
22 |
hf:
|
23 |
chmod 777 hf.sh
|
24 |
-
./hf.sh
|
|
|
|
|
|
|
|
|
|
|
|
21 |
|
22 |
hf:
|
23 |
chmod 777 hf.sh
|
24 |
+
./hf.sh
|
25 |
+
|
26 |
+
requirements:
|
27 |
+
uv pip compile --no-annotate pyproject.toml --no-deps --no-strip-extras --no-header \
|
28 |
+
| sed -E 's/([a-zA-Z0-9_-]+(\[[a-zA-Z0-9_,-]+\])?)[=><~!].*/\1/g' \
|
29 |
+
> requirements.txt
|
src/app.py
CHANGED
@@ -6,6 +6,7 @@ from tools import (
|
|
6 |
change_color_objects_hsv,
|
7 |
change_color_objects_lab,
|
8 |
privacy_preserve_image,
|
|
|
9 |
)
|
10 |
|
11 |
gr.set_static_paths(paths=[Path.cwd().absolute() / "assets"])
|
@@ -30,7 +31,7 @@ hsv_df_input = gr.Dataframe(
|
|
30 |
lab_df_input = gr.Dataframe(
|
31 |
headers=["Object", "New A", "New B"],
|
32 |
datatype=["str", "number", "number"],
|
33 |
-
col_count=(3,"fixed"),
|
34 |
label="Target Objects and New Settings",
|
35 |
type="array",
|
36 |
)
|
@@ -119,13 +120,37 @@ privacy_preserve_tool = gr.Interface(
|
|
119 |
],
|
120 |
)
|
121 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
122 |
demo = gr.TabbedInterface(
|
123 |
[
|
124 |
change_color_objects_hsv_tool,
|
125 |
change_color_objects_lab_tool,
|
126 |
privacy_preserve_tool,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
127 |
],
|
128 |
-
["Change Color Objects HSV", "Change Color Objects LAB", "Privacy Preserving Tool"],
|
129 |
title=title,
|
130 |
theme=gr.themes.Default(
|
131 |
primary_hue="blue",
|
|
|
6 |
change_color_objects_hsv,
|
7 |
change_color_objects_lab,
|
8 |
privacy_preserve_image,
|
9 |
+
remove_background,
|
10 |
)
|
11 |
|
12 |
gr.set_static_paths(paths=[Path.cwd().absolute() / "assets"])
|
|
|
31 |
lab_df_input = gr.Dataframe(
|
32 |
headers=["Object", "New A", "New B"],
|
33 |
datatype=["str", "number", "number"],
|
34 |
+
col_count=(3, "fixed"),
|
35 |
label="Target Objects and New Settings",
|
36 |
type="array",
|
37 |
)
|
|
|
120 |
],
|
121 |
)
|
122 |
|
123 |
+
# Gradio interface for the background-removal tool: takes a single image
# (delivered to `remove_background` as a PIL image) and shows the returned image.
remove_background_tool = gr.Interface(
    fn=remove_background,
    inputs=[
        gr.Image(label="Input Image", type="pil"),
    ],
    outputs=gr.Image(label="Output Image"),
    title="Remove Image Background Tool",
    # User-facing text: the original sentence was missing the word "to".
    description="Upload an image to remove the background.",
    # Each inner list is one example row; the single entry fills the image input.
    examples=[
        [
            "https://raw.githubusercontent.com/mahan-ym/ImageAlfred/main/src/assets/examples/test_5.jpg",
        ],
        [
            "https://raw.githubusercontent.com/mahan-ym/ImageAlfred/main/src/assets/examples/test_6.jpg",
        ],
    ],
)
|
140 |
+
|
141 |
demo = gr.TabbedInterface(
|
142 |
[
|
143 |
change_color_objects_hsv_tool,
|
144 |
change_color_objects_lab_tool,
|
145 |
privacy_preserve_tool,
|
146 |
+
remove_background_tool,
|
147 |
+
],
|
148 |
+
[
|
149 |
+
"Change Color Objects HSV",
|
150 |
+
"Change Color Objects LAB",
|
151 |
+
"Privacy Preserving Tool",
|
152 |
+
"Remove Background Tool",
|
153 |
],
|
|
|
154 |
title=title,
|
155 |
theme=gr.themes.Default(
|
156 |
primary_hue="blue",
|
src/assets/examples/test_6.jpg
ADDED
![]() |
Git LFS Details
|
src/assets/icons/hf-logo.svg
ADDED
|
src/assets/icons/python-logo-only.svg
ADDED
|
src/modal_app.py
CHANGED
@@ -48,6 +48,10 @@ image = (
|
|
48 |
"git+https://github.com/luca-medeiros/lang-segment-anything.git",
|
49 |
gpu="A10G",
|
50 |
)
|
|
|
|
|
|
|
|
|
51 |
)
|
52 |
|
53 |
|
@@ -79,11 +83,14 @@ def lang_sam_segment(
|
|
79 |
if len(langsam_results[0]["labels"]) == 0:
|
80 |
print("No masks found for the given prompt.")
|
81 |
return None
|
82 |
-
|
83 |
print(f"found {len(langsam_results[0]['labels'])} masks for prompt: {prompt}")
|
84 |
print("labels:", langsam_results[0]["labels"])
|
85 |
print("scores:", langsam_results[0]["scores"])
|
86 |
-
print(
|
|
|
|
|
|
|
87 |
|
88 |
return langsam_results
|
89 |
|
@@ -284,7 +291,7 @@ def preserve_privacy(
|
|
284 |
|
285 |
for result in langsam_results:
|
286 |
print(f"result: {result}")
|
287 |
-
|
288 |
for i, mask in enumerate(result["masks"]):
|
289 |
if "mask_scores" in result:
|
290 |
if (
|
@@ -310,3 +317,27 @@ def preserve_privacy(
|
|
310 |
output_image_pil = Image.fromarray(img_array)
|
311 |
|
312 |
return output_image_pil
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
48 |
"git+https://github.com/luca-medeiros/lang-segment-anything.git",
|
49 |
gpu="A10G",
|
50 |
)
|
51 |
+
.pip_install(
|
52 |
+
"git+https://github.com/PramaLLC/BEN2.git#egg=ben2",
|
53 |
+
gpu="A10G",
|
54 |
+
)
|
55 |
)
|
56 |
|
57 |
|
|
|
83 |
if len(langsam_results[0]["labels"]) == 0:
|
84 |
print("No masks found for the given prompt.")
|
85 |
return None
|
86 |
+
|
87 |
print(f"found {len(langsam_results[0]['labels'])} masks for prompt: {prompt}")
|
88 |
print("labels:", langsam_results[0]["labels"])
|
89 |
print("scores:", langsam_results[0]["scores"])
|
90 |
+
print(
|
91 |
+
"masks scores:",
|
92 |
+
langsam_results[0].get("mask_scores", "No mask scores available"),
|
93 |
+
) # noqa: E501
|
94 |
|
95 |
return langsam_results
|
96 |
|
|
|
291 |
|
292 |
for result in langsam_results:
|
293 |
print(f"result: {result}")
|
294 |
+
|
295 |
for i, mask in enumerate(result["masks"]):
|
296 |
if "mask_scores" in result:
|
297 |
if (
|
|
|
317 |
output_image_pil = Image.fromarray(img_array)
|
318 |
|
319 |
return output_image_pil
|
320 |
+
|
321 |
+
|
322 |
+
@app.function(
    gpu="A10G",
    image=image,
    volumes={volume_path: volume},
    timeout=2 * 60,
)
def remove_background(image_pil: Image.Image) -> Image.Image:
    """Remove the background of an image using the BEN2 model.

    Loads ``PramaLLC/BEN2`` through ``BEN_Base.from_pretrained`` on each call,
    places it on CUDA when ``torch.cuda.is_available()`` (CPU otherwise),
    switches it to eval mode and runs ``inference`` with
    ``refine_foreground=True``.

    Args:
        image_pil: The image to process.

    Returns:
        The value produced by ``model.inference`` for ``image_pil``.
    """
    # Imported inside the function body, as in the original.
    # NOTE(review): presumably these packages exist only in the Modal image —
    # confirm before hoisting them to module level.
    import torch
    from ben2 import BEN_Base

    device_name = "cuda" if torch.cuda.is_available() else "cpu"
    device = torch.device(device_name)
    print(f"Using device: {device}")
    print("type of image_pil:", type(image_pil))

    model = BEN_Base.from_pretrained("PramaLLC/BEN2")
    model.to(device).eval()

    output_image = model.inference(image_pil, refine_foreground=True)
    print(f"output type: {type(output_image)}")
    return output_image
|
src/tools.py
CHANGED
@@ -9,6 +9,35 @@ from PIL import Image
|
|
9 |
modal_app_name = "ImageAlfred"
|
10 |
|
11 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
12 |
def privacy_preserve_image(
|
13 |
input_img,
|
14 |
input_prompt,
|
@@ -99,7 +128,7 @@ def change_color_objects_hsv(
|
|
99 |
)
|
100 |
if not input_img:
|
101 |
raise gr.Error("input img cannot be None or empty.")
|
102 |
-
|
103 |
print("before processing input:", user_input)
|
104 |
valid_pattern = re.compile(r"^[a-zA-Z\s]+$")
|
105 |
for item in user_input:
|
@@ -198,7 +227,7 @@ def change_color_objects_lab(
|
|
198 |
raise gr.Error("input img cannot be None or empty.")
|
199 |
valid_pattern = re.compile(r"^[a-zA-Z\s]+$")
|
200 |
print("before processing input:", user_input)
|
201 |
-
|
202 |
for item in user_input:
|
203 |
if len(item) != 3:
|
204 |
raise gr.Error(
|
|
|
9 |
modal_app_name = "ImageAlfred"
|
10 |
|
11 |
|
12 |
+
def remove_background(
    input_img,
) -> np.ndarray | Image.Image | str | Path | None:
    """
    Remove the background of the image.

    The image is sent to the remote "remove_background" Modal function and the
    resulting PIL image is returned. A gr.Error is raised when the input is
    missing, when the server returns None, or when the server returns anything
    other than a PIL image.

    Args:
        input_img: Input image, or a URL string / base64 string of the image. Cannot be None or empty.
    Returns:
        Image.Image: The image with its background removed.
    """  # noqa: E501
    if not input_img:
        raise gr.Error("Input image cannot be None or empty.")

    # Resolve the deployed function through the module-level app-name constant
    # instead of repeating the literal, so the name is defined in one place.
    func = modal.Function.from_name(modal_app_name, "remove_background")
    output_pil = func.remote(image_pil=input_img)

    # Validate the remote result before handing it back to Gradio.
    if output_pil is None:
        raise gr.Error("Received None from server.")
    if not isinstance(output_pil, Image.Image):
        raise gr.Error(
            f"Expected Image.Image from server function, got {type(output_pil)}"
        )

    return output_pil
|
39 |
+
|
40 |
+
|
41 |
def privacy_preserve_image(
|
42 |
input_img,
|
43 |
input_prompt,
|
|
|
128 |
)
|
129 |
if not input_img:
|
130 |
raise gr.Error("input img cannot be None or empty.")
|
131 |
+
|
132 |
print("before processing input:", user_input)
|
133 |
valid_pattern = re.compile(r"^[a-zA-Z\s]+$")
|
134 |
for item in user_input:
|
|
|
227 |
raise gr.Error("input img cannot be None or empty.")
|
228 |
valid_pattern = re.compile(r"^[a-zA-Z\s]+$")
|
229 |
print("before processing input:", user_input)
|
230 |
+
|
231 |
for item in user_input:
|
232 |
if len(item) != 3:
|
233 |
raise gr.Error(
|