AIRider commited on
Commit
d5de94c
·
verified ·
1 Parent(s): 9e01353

Update src/app.py

Browse files
Files changed (1) hide show
  1. src/app.py +60 -187
src/app.py CHANGED
@@ -1,7 +1,7 @@
1
  import tempfile
2
  import time
 
3
  from collections.abc import Sequence
4
- from typing import Any, cast
5
 
6
  import gradio as gr
7
  import numpy as np
@@ -14,7 +14,6 @@ from PIL import Image
14
  from pymatting.foreground.estimate_foreground_ml import estimate_foreground_ml
15
  from refiners.fluxion.utils import no_grad
16
  from refiners.solutions import BoxSegmenter
17
- from transformers import GroundingDinoForObjectDetection, GroundingDinoProcessor
18
 
19
  BoundingBox = tuple[int, int, int, int]
20
 
@@ -23,18 +22,11 @@ pillow_heif.register_avif_opener()
23
 
24
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
25
 
26
- # weird dance because ZeroGPU
27
  segmenter = BoxSegmenter(device="cpu")
28
  segmenter.device = device
29
  segmenter.model = segmenter.model.to(device=segmenter.device)
30
 
31
- gd_model_path = "IDEA-Research/grounding-dino-base"
32
- gd_processor = GroundingDinoProcessor.from_pretrained(gd_model_path)
33
- gd_model = GroundingDinoForObjectDetection.from_pretrained(gd_model_path, torch_dtype=torch.float32)
34
- gd_model = gd_model.to(device=device) # type: ignore
35
- assert isinstance(gd_model, GroundingDinoForObjectDetection)
36
-
37
-
38
  def bbox_union(bboxes: Sequence[list[int]]) -> BoundingBox | None:
39
  if not bboxes:
40
  return None
@@ -48,32 +40,6 @@ def bbox_union(bboxes: Sequence[list[int]]) -> BoundingBox | None:
48
  max(bbox[3] for bbox in bboxes),
49
  )
50
 
51
-
52
- def corners_to_pixels_format(bboxes: torch.Tensor, width: int, height: int) -> torch.Tensor:
53
- x1, y1, x2, y2 = bboxes.round().to(torch.int32).unbind(-1)
54
- return torch.stack((x1.clamp_(0, width), y1.clamp_(0, height), x2.clamp_(0, width), y2.clamp_(0, height)), dim=-1)
55
-
56
-
57
- def gd_detect(img: Image.Image, prompt: str) -> BoundingBox | None:
58
- assert isinstance(gd_processor, GroundingDinoProcessor)
59
-
60
- # Grounding Dino expects a dot after each category.
61
- inputs = gd_processor(images=img, text=f"{prompt}.", return_tensors="pt").to(device=device)
62
-
63
- with no_grad():
64
- outputs = gd_model(**inputs)
65
- width, height = img.size
66
- results: dict[str, Any] = gd_processor.post_process_grounded_object_detection(
67
- outputs,
68
- inputs["input_ids"],
69
- target_sizes=[(height, width)],
70
- )[0]
71
- assert "boxes" in results and isinstance(results["boxes"], torch.Tensor)
72
-
73
- bboxes = corners_to_pixels_format(results["boxes"].cpu(), width, height)
74
- return bbox_union(bboxes.numpy().tolist())
75
-
76
-
77
  def apply_mask(
78
  img: Image.Image,
79
  mask_img: Image.Image,
@@ -86,54 +52,39 @@ def apply_mask(
86
  if defringe:
87
  # Mitigate edge halo effects via color decontamination
88
  rgb, alpha = np.asarray(img) / 255.0, np.asarray(mask_img) / 255.0
89
- foreground = cast(np.ndarray[Any, np.dtype[np.uint8]], estimate_foreground_ml(rgb, alpha))
90
  img = Image.fromarray((foreground * 255).astype("uint8"))
91
 
92
  result = Image.new("RGBA", img.size)
93
  result.paste(img, (0, 0), mask_img)
94
  return result
95
 
96
-
97
  @spaces.GPU
98
  def _gpu_process(
99
  img: Image.Image,
100
- prompt: str | BoundingBox | None,
101
  ) -> tuple[Image.Image, BoundingBox | None, list[str]]:
102
- # Because of ZeroGPU shenanigans, we need a *single* function with the
103
- # `spaces.GPU` decorator that *does not* contain postprocessing.
104
-
105
  time_log: list[str] = []
106
-
107
- if isinstance(prompt, str):
108
- t0 = time.time()
109
- bbox = gd_detect(img, prompt)
110
- time_log.append(f"detect: {time.time() - t0}")
111
- if not bbox:
112
- print(time_log[0])
113
- raise gr.Error("No object detected")
114
- else:
115
- bbox = prompt
116
-
117
  t0 = time.time()
118
  mask = segmenter(img, bbox)
119
  time_log.append(f"segment: {time.time() - t0}")
120
 
121
  return mask, bbox, time_log
122
 
123
-
124
  def _process(
125
  img: Image.Image,
126
- prompt: str | BoundingBox | None,
127
  ) -> tuple[tuple[Image.Image, Image.Image], gr.DownloadButton]:
128
  # enforce max dimensions for pymatting performance reasons
129
  if img.width > 2048 or img.height > 2048:
130
  orig_res = max(img.width, img.height)
131
  img.thumbnail((2048, 2048))
132
- if isinstance(prompt, tuple):
133
- x0, y0, x1, y1 = (int(x * 2048 / orig_res) for x in prompt)
134
- prompt = (x0, y0, x1, y1)
135
 
136
- mask, bbox, time_log = _gpu_process(img, prompt)
137
 
138
  t0 = time.time()
139
  masked_alpha = apply_mask(img, mask, defringe=True)
@@ -152,7 +103,6 @@ def _process(
152
 
153
  return (img, masked_rgb), gr.DownloadButton(value=temp.name, interactive=True)
154
 
155
-
156
  def process_bbox(prompts: dict[str, Any]) -> tuple[tuple[Image.Image, Image.Image], gr.DownloadButton]:
157
  assert isinstance(img := prompts["image"], Image.Image)
158
  assert isinstance(boxes := prompts["boxes"], list)
@@ -164,38 +114,17 @@ def process_bbox(prompts: dict[str, Any]) -> tuple[tuple[Image.Image, Image.Imag
164
  bbox = None
165
  return _process(img, bbox)
166
 
167
-
168
  def on_change_bbox(prompts: dict[str, Any] | None):
169
  return gr.update(interactive=prompts is not None)
170
 
171
-
172
- def process_prompt(img: Image.Image, prompt: str) -> tuple[tuple[Image.Image, Image.Image], gr.DownloadButton]:
173
- return _process(img, prompt)
174
-
175
-
176
- def on_change_prompt(img: Image.Image | None, prompt: str | None):
177
- return gr.update(interactive=bool(img and prompt))
178
-
179
-
180
  TITLE = """
181
  <center>
182
-
183
- <div style="
184
- background-color: #ff9100;
185
- color: #1f2937;
186
- padding: 0.5rem 1rem;
187
- font-size: 1.25rem;
188
- ">
189
- 🚀 For an optimized version of this space, try out the
190
- <a href="https://finegrain.ai/editor?utm_source=hf&utm_campaign=object-cutter" target="_blank">Finegrain Editor</a>! You'll find there all our AI tools made available in a nice UI. 🚀
191
- </div>
192
-
193
  <h1 style="font-size: 1.5rem; margin-bottom: 0.5rem;">
194
- Object Cutter Powered By Refiners
195
  </h1>
196
 
197
  <p>
198
- Create high-quality HD cutouts for any object in your image with just a text prompt — no manual work required!
199
  <br>
200
  The object will be available on a transparent background, ready to paste elsewhere.
201
  </p>
@@ -211,118 +140,62 @@ TITLE = """
211
  href="https://huggingface.co/datasets/Nfiniteai/product-masks-sample"
212
  target="_blank"
213
  >synthetic data provided by Nfinite</a>.
214
- <br>
215
- It is powered by Refiners, our open source micro-framework for simple foundation model adaptation.
216
- If you enjoyed it, please consider starring Refiners on GitHub!
217
  </p>
218
-
219
- <a href="https://github.com/finegrain-ai/refiners" target="_blank">
220
- <img src="https://img.shields.io/github/stars/finegrain-ai/refiners?style=social" />
221
- </a>
222
-
223
  </center>
224
  """
225
 
226
  with gr.Blocks() as demo:
227
  gr.HTML(TITLE)
228
-
229
- with gr.Tab("By prompt", id="tab_prompt"):
230
- with gr.Row():
231
- with gr.Column():
232
- iimg = gr.Image(type="pil", label="Input")
233
- prompt = gr.Textbox(label="What should we cut?")
234
- btn = gr.ClearButton(value="Cut Out Object", interactive=False)
235
- with gr.Column():
236
- oimg = ImageSlider(label="Before / After", show_download_button=False, interactive=False)
237
- dlbt = gr.DownloadButton("Download Cutout", interactive=False)
238
-
239
- btn.add(oimg)
240
-
241
- for inp in [iimg, prompt]:
242
- inp.change(
243
- fn=on_change_prompt,
244
- inputs=[iimg, prompt],
245
- outputs=[btn],
246
  )
247
- btn.click(
248
- fn=process_prompt,
249
- inputs=[iimg, prompt],
250
- outputs=[oimg, dlbt],
251
- )
252
-
253
- examples = [
254
- [
255
- "examples/potted-plant.jpg",
256
- "potted plant",
257
- ],
258
- [
259
- "examples/chair.jpg",
260
- "chair",
261
- ],
262
- [
263
- "examples/black-lamp.jpg",
264
- "black lamp",
265
- ],
266
- ]
267
-
268
- ex = gr.Examples(
269
- examples=examples,
270
- inputs=[iimg, prompt],
271
- outputs=[oimg, dlbt],
272
- fn=process_prompt,
273
- cache_examples=True,
274
- )
275
-
276
- with gr.Tab("By bounding box", id="tab_bb"):
277
- with gr.Row():
278
- with gr.Column():
279
- annotator = image_annotator(
280
- image_type="pil",
281
- disable_edit_boxes=True,
282
- show_download_button=False,
283
- show_share_button=False,
284
- single_box=True,
285
- label="Input",
286
- )
287
- btn = gr.ClearButton(value="Cut Out Object", interactive=False)
288
- with gr.Column():
289
- oimg = ImageSlider(label="Before / After", show_download_button=False)
290
- dlbt = gr.DownloadButton("Download Cutout", interactive=False)
291
 
292
- btn.add(oimg)
293
 
294
- annotator.change(
295
- fn=on_change_bbox,
296
- inputs=[annotator],
297
- outputs=[btn],
298
- )
299
- btn.click(
300
- fn=process_bbox,
301
- inputs=[annotator],
302
- outputs=[oimg, dlbt],
303
- )
304
-
305
- examples = [
306
- {
307
- "image": "examples/potted-plant.jpg",
308
- "boxes": [{"xmin": 51, "ymin": 511, "xmax": 639, "ymax": 1255}],
309
- },
310
- {
311
- "image": "examples/chair.jpg",
312
- "boxes": [{"xmin": 98, "ymin": 330, "xmax": 973, "ymax": 1468}],
313
- },
314
- {
315
- "image": "examples/black-lamp.jpg",
316
- "boxes": [{"xmin": 88, "ymin": 148, "xmax": 700, "ymax": 1414}],
317
- },
318
- ]
319
 
320
- ex = gr.Examples(
321
- examples=examples,
322
- inputs=[annotator],
323
- outputs=[oimg, dlbt],
324
- fn=process_bbox,
325
- cache_examples=True,
326
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
327
 
328
- demo.launch(share=False)
 
1
  import tempfile
2
  import time
3
+ from typing import Any
4
  from collections.abc import Sequence
 
5
 
6
  import gradio as gr
7
  import numpy as np
 
14
  from pymatting.foreground.estimate_foreground_ml import estimate_foreground_ml
15
  from refiners.fluxion.utils import no_grad
16
  from refiners.solutions import BoxSegmenter
 
17
 
18
  BoundingBox = tuple[int, int, int, int]
19
 
 
22
 
23
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
24
 
25
+ # Initialize segmenter
26
  segmenter = BoxSegmenter(device="cpu")
27
  segmenter.device = device
28
  segmenter.model = segmenter.model.to(device=segmenter.device)
29
 
 
 
 
 
 
 
 
30
  def bbox_union(bboxes: Sequence[list[int]]) -> BoundingBox | None:
31
  if not bboxes:
32
  return None
 
40
  max(bbox[3] for bbox in bboxes),
41
  )
42
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43
  def apply_mask(
44
  img: Image.Image,
45
  mask_img: Image.Image,
 
52
  if defringe:
53
  # Mitigate edge halo effects via color decontamination
54
  rgb, alpha = np.asarray(img) / 255.0, np.asarray(mask_img) / 255.0
55
+ foreground = estimate_foreground_ml(rgb, alpha)
56
  img = Image.fromarray((foreground * 255).astype("uint8"))
57
 
58
  result = Image.new("RGBA", img.size)
59
  result.paste(img, (0, 0), mask_img)
60
  return result
61
 
 
62
  @spaces.GPU
63
  def _gpu_process(
64
  img: Image.Image,
65
+ bbox: BoundingBox | None,
66
  ) -> tuple[Image.Image, BoundingBox | None, list[str]]:
 
 
 
67
  time_log: list[str] = []
68
+
 
 
 
 
 
 
 
 
 
 
69
  t0 = time.time()
70
  mask = segmenter(img, bbox)
71
  time_log.append(f"segment: {time.time() - t0}")
72
 
73
  return mask, bbox, time_log
74
 
 
75
  def _process(
76
  img: Image.Image,
77
+ bbox: BoundingBox | None,
78
  ) -> tuple[tuple[Image.Image, Image.Image], gr.DownloadButton]:
79
  # enforce max dimensions for pymatting performance reasons
80
  if img.width > 2048 or img.height > 2048:
81
  orig_res = max(img.width, img.height)
82
  img.thumbnail((2048, 2048))
83
+ if isinstance(bbox, tuple):
84
+ x0, y0, x1, y1 = (int(x * 2048 / orig_res) for x in bbox)
85
+ bbox = (x0, y0, x1, y1)
86
 
87
+ mask, bbox, time_log = _gpu_process(img, bbox)
88
 
89
  t0 = time.time()
90
  masked_alpha = apply_mask(img, mask, defringe=True)
 
103
 
104
  return (img, masked_rgb), gr.DownloadButton(value=temp.name, interactive=True)
105
 
 
106
  def process_bbox(prompts: dict[str, Any]) -> tuple[tuple[Image.Image, Image.Image], gr.DownloadButton]:
107
  assert isinstance(img := prompts["image"], Image.Image)
108
  assert isinstance(boxes := prompts["boxes"], list)
 
114
  bbox = None
115
  return _process(img, bbox)
116
 
 
117
  def on_change_bbox(prompts: dict[str, Any] | None):
118
  return gr.update(interactive=prompts is not None)
119
 
 
 
 
 
 
 
 
 
 
120
  TITLE = """
121
  <center>
 
 
 
 
 
 
 
 
 
 
 
122
  <h1 style="font-size: 1.5rem; margin-bottom: 0.5rem;">
123
+ Object Cutter With Bounding Box
124
  </h1>
125
 
126
  <p>
127
+ Create high-quality HD cutouts for any object in your image using bounding box selection.
128
  <br>
129
  The object will be available on a transparent background, ready to paste elsewhere.
130
  </p>
 
140
  href="https://huggingface.co/datasets/Nfiniteai/product-masks-sample"
141
  target="_blank"
142
  >synthetic data provided by Nfinite</a>.
 
 
 
143
  </p>
 
 
 
 
 
144
  </center>
145
  """
146
 
147
  with gr.Blocks() as demo:
148
  gr.HTML(TITLE)
149
+
150
+ with gr.Row():
151
+ with gr.Column():
152
+ annotator = image_annotator(
153
+ image_type="pil",
154
+ disable_edit_boxes=True,
155
+ show_download_button=False,
156
+ show_share_button=False,
157
+ single_box=True,
158
+ label="Input",
 
 
 
 
 
 
 
 
159
  )
160
+ btn = gr.ClearButton(value="Cut Out Object", interactive=False)
161
+ with gr.Column():
162
+ oimg = ImageSlider(label="Before / After", show_download_button=False)
163
+ dlbt = gr.DownloadButton("Download Cutout", interactive=False)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
164
 
165
+ btn.add(oimg)
166
 
167
+ annotator.change(
168
+ fn=on_change_bbox,
169
+ inputs=[annotator],
170
+ outputs=[btn],
171
+ )
172
+ btn.click(
173
+ fn=process_bbox,
174
+ inputs=[annotator],
175
+ outputs=[oimg, dlbt],
176
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
177
 
178
+ examples = [
179
+ {
180
+ "image": "examples/potted-plant.jpg",
181
+ "boxes": [{"xmin": 51, "ymin": 511, "xmax": 639, "ymax": 1255}],
182
+ },
183
+ {
184
+ "image": "examples/chair.jpg",
185
+ "boxes": [{"xmin": 98, "ymin": 330, "xmax": 973, "ymax": 1468}],
186
+ },
187
+ {
188
+ "image": "examples/black-lamp.jpg",
189
+ "boxes": [{"xmin": 88, "ymin": 148, "xmax": 700, "ymax": 1414}],
190
+ },
191
+ ]
192
+
193
+ ex = gr.Examples(
194
+ examples=examples,
195
+ inputs=[annotator],
196
+ outputs=[oimg, dlbt],
197
+ fn=process_bbox,
198
+ cache_examples=True,
199
+ )
200
 
201
+ demo.launch(share=False)