mahan_ym committed
Commit e4ccc11 · 1 parent: e608228

Change foundation model from Grounding DINO to CLIP

Files changed (3):
  1. src/app.py +21 -8
  2. src/modal_app.py +228 -90
  3. src/tools.py +7 -7
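Overview: this commit replaces the LangSAM backend (Grounding DINO + SAM) with a two-stage pipeline in which CLIPSeg proposes bounding boxes from text prompts and SAM 2 refines those boxes into masks. Below is a minimal standalone sketch of that flow. The checkpoints ("CIDAS/clipseg-rd64-refined", "facebook/sam2-hiera-large") match the diff; the example image path, the 0.5 threshold, and the simplified min/max box extraction (the committed code cuts one box per contour instead) are illustrative assumptions.

import numpy as np
import torch
from PIL import Image
from sam2.sam2_image_predictor import SAM2ImagePredictor
from transformers import CLIPSegForImageSegmentation, CLIPSegProcessor

image = Image.open("example.jpg").convert("RGB")  # hypothetical input image
prompts = ["pants"]

# Stage 1: CLIPSeg scores every pixel against each text prompt.
processor = CLIPSegProcessor.from_pretrained("CIDAS/clipseg-rd64-refined")
clipseg = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined")
inputs = processor(
    text=prompts,
    images=[image] * len(prompts),
    padding="max_length",
    return_tensors="pt",
)
with torch.no_grad():
    logits = clipseg(**inputs).logits
heat = torch.sigmoid(logits).reshape(len(prompts), *logits.shape[-2:]).numpy()

# Derive a box from the thresholded heatmap, rescaled to the original image size.
# (Simplification: min/max over all activated pixels; assumes the prompt matches
# something, otherwise xs/ys would be empty.)
w, h = image.size
ph, pw = heat.shape[-2:]
ys, xs = np.where(heat[0] > 0.5)
box = np.array(
    [xs.min() * w / pw, ys.min() * h / ph, xs.max() * w / pw, ys.max() * h / ph]
)

# Stage 2: SAM 2 refines the box prompt into a pixel-accurate mask.
predictor = SAM2ImagePredictor.from_pretrained("facebook/sam2-hiera-large")
with torch.inference_mode():
    predictor.set_image(image)
    masks, scores, _ = predictor.predict(box=box[None, :], multimask_output=False)
print(masks.shape, scores)  # one mask per box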
src/app.py CHANGED
@@ -31,7 +31,7 @@ lab_df_input = gr.Dataframe(
     headers=["Object", "New A", "New B"],
     datatype=["str", "number", "number"],
     col_count=(3, "fixed"),
-    label="Target Objects and New Settings",
+    label="Target Objects and New Settings.(0-255 -- 128 = Neutral)",
     type="array",
 )
 
@@ -78,15 +78,15 @@ change_color_objects_lab_tool = gr.Interface(
     examples=[
         [
             "https://raw.githubusercontent.com/mahan-ym/ImageAlfred/main/src/assets/examples/test_1.jpg",
-            [["pants", 128, 1]],
+            [["pants", 112, 128]],
         ],
         [
             "https://raw.githubusercontent.com/mahan-ym/ImageAlfred/main/src/assets/examples/test_4.jpg",
-            [["desk", 15, 0.5], ["left cup", 40, 1.1]],
+            [["desk", 166, 193]],
         ],
         [
             "https://raw.githubusercontent.com/mahan-ym/ImageAlfred/main/src/assets/examples/test_5.jpg",
-            [["suits", 60, 1.5], ["pants", 10, 0.8]],
+            [["suits coat", 110, 133]],
         ],
     ],
 )
@@ -117,6 +117,16 @@ privacy_preserve_tool = gr.Interface(
             "license plate.",
             10,
         ],
+        [
+            "https://raw.githubusercontent.com/mahan-ym/ImageAlfred/main/src/assets/examples/test_8.jpg",
+            "face.",
+            15,
+        ],
+        [
+            "https://raw.githubusercontent.com/mahan-ym/ImageAlfred/main/src/assets/examples/test_6.jpg",
+            "face.",
+            20,
+        ],
     ],
 )
@@ -135,21 +145,24 @@ remove_background_tool = gr.Interface(
         [
             "https://raw.githubusercontent.com/mahan-ym/ImageAlfred/main/src/assets/examples/test_6.jpg",
         ],
+        [
+            "https://raw.githubusercontent.com/mahan-ym/ImageAlfred/main/src/assets/examples/test_8.jpg",
+        ],
     ],
 )
 
 demo = gr.TabbedInterface(
     [
-        change_color_objects_hsv_tool,
-        change_color_objects_lab_tool,
         privacy_preserve_tool,
         remove_background_tool,
+        change_color_objects_hsv_tool,
+        change_color_objects_lab_tool,
     ],
     [
-        "Change Color Objects HSV",
-        "Change Color Objects LAB",
         "Privacy Preserving Tool",
         "Remove Background Tool",
+        "Change Color Objects HSV",
+        "Change Color Objects LAB",
     ],
     title=title,
     theme=gr.themes.Default(
src/modal_app.py CHANGED
@@ -5,7 +5,6 @@ import cv2
 import modal
 import numpy as np
 from PIL import Image
-from rapidfuzz import process
 
 app = modal.App("ImageAlfred")
 
@@ -30,14 +29,16 @@ image = (
             "TORCH_HOME": TORCH_HOME,
         }
     )
-    .apt_install("git")
+    .apt_install(
+        "git",
+    )
     .pip_install(
         "huggingface-hub",
         "hf_transfer",
         "Pillow",
         "numpy",
+        "transformers",
         "opencv-contrib-python-headless",
-        "RapidFuzz",
         gpu="A10G",
     )
     .pip_install(
@@ -46,10 +47,8 @@ image = (
         index_url="https://download.pytorch.org/whl/cu124",
         gpu="A10G",
     )
-    .pip_install(
-        "git+https://github.com/luca-medeiros/lang-segment-anything.git",
-        gpu="A10G",
-    )
+    .pip_install("git+https://github.com/openai/CLIP.git", gpu="A10G")
+    .pip_install("git+https://github.com/facebookresearch/sam2.git", gpu="A10G")
     .pip_install(
         "git+https://github.com/PramaLLC/BEN2.git#egg=ben2",
         gpu="A10G",
@@ -58,43 +57,180 @@
 
 
 @app.function(
+    image=image,
     gpu="A10G",
-    image=image,
     volumes={volume_path: volume},
-    # min_containers=1,
     timeout=60 * 3,
 )
-def lang_sam_segment(
+def prompt_segment(
     image_pil: Image.Image,
-    prompt: str,
-    box_threshold=0.3,
-    text_threshold=0.25,
-) -> list:
-    """Segments an image using LangSAM based on a text prompt.
-    This function uses LangSAM to segment objects in the image based on the provided prompt.
-    """  # noqa: E501
-    from lang_sam import LangSAM  # type: ignore
-
-    model = LangSAM(sam_type="sam2.1_hiera_large")
-    langsam_results = model.predict(
-        images_pil=[image_pil],
-        texts_prompt=[prompt],
-        box_threshold=box_threshold,
-        text_threshold=text_threshold,
-    )
-    if len(langsam_results[0]["labels"]) == 0:
-        print("No masks found for the given prompt.")
-        return None
-
-    print(f"found {len(langsam_results[0]['labels'])} masks for prompt: {prompt}")
-    print("labels:", langsam_results[0]["labels"])
-    print("scores:", langsam_results[0]["scores"])
-    print(
-        "masks scores:",
-        langsam_results[0].get("mask_scores", "No mask scores available"),
-    )  # noqa: E501
-
-    return langsam_results
+    prompts: list[str],
+) -> list[dict]:
+    clip_results = clip.remote(image_pil, prompts)
+
+    if not clip_results:
+        print("No boxes returned from CLIP.")
+        return None
+
+    boxes = np.array(clip_results["boxes"])
+
+    sam_result_masks, sam_result_scores = sam2.remote(image_pil=image_pil, boxes=boxes)
+
+    print(f"sam_result_mask {sam_result_masks}")
+
+    if not sam_result_masks.any():
+        print("No masks or scores returned from SAM2.")
+        return None
+
+    if sam_result_masks.ndim == 3:
+        # If the masks are in 3D, we need to convert them to 4D
+        sam_result_masks = [sam_result_masks]
+
+    results = {
+        "labels": clip_results["labels"],
+        "boxes": boxes,
+        "clip_scores": clip_results["scores"],
+        "sam_masking_scores": sam_result_scores,
+        "masks": sam_result_masks,
+    }
+    return results
+
+
+@app.function(
+    image=image,
+    gpu="A10G",
+    volumes={volume_path: volume},
+    timeout=60 * 3,
+)
+def sam2(image_pil: Image.Image, boxes: list[np.ndarray]) -> list[dict]:
+    import torch
+    from sam2.sam2_image_predictor import SAM2ImagePredictor
+
+    predictor = SAM2ImagePredictor.from_pretrained("facebook/sam2-hiera-large")
+
+    with torch.inference_mode(), torch.autocast("cuda", dtype=torch.bfloat16):
+        predictor.set_image(image_pil)
+        masks, scores, _ = predictor.predict(
+            point_coords=None,
+            point_labels=None,
+            box=boxes,
+            multimask_output=False,
+        )
+    return masks, scores
+
+
+@app.function(
+    image=image,
+    gpu="A10G",
+    volumes={volume_path: volume},
+    timeout=60 * 3,
+)
+def clip(
+    image_pil: Image.Image,
+    prompts: list[str],
+) -> list[dict]:
+    """
+    returns:
+        dict with keys each are lists:
+        - labels: str, the prompt used for the prediction
+        - scores: float, confidence score of the prediction
+        - boxes: np.array representing bounding box coordinates
+    """
+    from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation
+    import torch
+
+    processor = CLIPSegProcessor.from_pretrained(
+        "CIDAS/clipseg-rd64-refined",
+        use_fast=True,
+    )
+    model = CLIPSegForImageSegmentation.from_pretrained("CIDAS/clipseg-rd64-refined")
+
+    # Get original image dimensions
+    orig_width, orig_height = image_pil.size
+
+    inputs = processor(
+        text=prompts,
+        images=[image_pil] * len(prompts),
+        padding="max_length",
+        return_tensors="pt",
+    )
+    # predict
+    with torch.no_grad():
+        outputs = model(**inputs)
+        preds = outputs.logits.unsqueeze(1)
+
+    # Get the dimensions of the prediction output
+    pred_height, pred_width = preds.shape[-2:]
+
+    # Calculate scaling factors
+    width_scale = orig_width / pred_width
+    height_scale = orig_height / pred_height
+
+    labels = []
+    scores = []
+    boxes = []
+
+    # Process each prediction to find bounding boxes in high probability regions
+    for i, prompt in enumerate(prompts):
+        # Apply sigmoid to get probability map
+        pred_tensor = torch.sigmoid(preds[i][0])
+        # Convert tensor to numpy array
+        pred_np = pred_tensor.cpu().numpy()
+
+        # Convert to uint8 for OpenCV processing
+        heatmap = (pred_np * 255).astype(np.uint8)
+
+        # Apply threshold to find high probability regions
+        _, binary = cv2.threshold(heatmap, 127, 255, cv2.THRESH_BINARY)
+
+        # Find contours in thresholded image
+        contours, _ = cv2.findContours(
+            binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
+        )
+
+        # Process each contour to get bounding boxes
+        for contour in contours:
+            # Skip very small contours that might be noise
+            if cv2.contourArea(contour) < 100:  # Minimum area threshold
+                continue
+
+            # Get bounding box coordinates in prediction space
+            x, y, w, h = cv2.boundingRect(contour)
+
+            # Scale coordinates to original image dimensions
+            x_orig = int(x * width_scale)
+            y_orig = int(y * height_scale)
+            w_orig = int(w * width_scale)
+            h_orig = int(h * height_scale)
+
+            # Calculate confidence score based on average probability in the region
+            mask = np.zeros_like(pred_np)
+            cv2.drawContours(mask, [contour], 0, 1, -1)
+            confidence = float(np.mean(pred_np[mask == 1]))
+
+            labels.append(prompt)
+            scores.append(confidence)
+            boxes.append(
+                np.array(
+                    [
+                        x_orig,
+                        y_orig,
+                        x_orig + w_orig,
+                        y_orig + h_orig,
+                    ]
+                )
+            )
+
+    if labels == []:
+        return None
+
+    results = {
+        "labels": labels,
+        "scores": scores,
+        "boxes": boxes,
+    }
+    return results
 
 
 @app.function(
@@ -128,14 +264,16 @@ def change_image_objects_hsv(
         "targets_config must be a list of lists, each containing [target_name, hue, saturation_scale]."  # noqa: E501
     )
     print("Change image objects hsv targets config:", targets_config)
-    prompts = ". ".join(target[0] for target in targets_config)
+    prompts = [target[0].strip() for target in targets_config]
 
-    langsam_results = lang_sam_segment.remote(image_pil=image_pil, prompt=prompts)
-    if not langsam_results:
+    prompt_segment_results = prompt_segment.remote(
+        image_pil=image_pil,
+        prompts=prompts,
+    )
+    if not prompt_segment_results:
         return image_pil
-    input_labels = [target[0] for target in targets_config]
-    output_labels = langsam_results[0]["labels"]
-    scores = langsam_results[0]["scores"]
+
+    output_labels = prompt_segment_results["labels"]
 
     img_array = np.array(image_pil)
     img_hsv = cv2.cvtColor(img_array, cv2.COLOR_RGB2HSV).astype(np.float32)
@@ -144,13 +282,14 @@ def change_image_objects_hsv(
         if not label or label == "":
             print("Skipping empty label.")
             continue
-        input_label, score, _ = process.extractOne(label, input_labels)
-        input_label_idx = input_labels.index(input_label)
-
+        if label not in prompts:
+            print(f"Label '{label}' not found in prompts. Skipping.")
+            continue
+        input_label_idx = prompts.index(label)
         target_rgb = targets_config[input_label_idx][1:]
        target_hsv = cv2.cvtColor(np.uint8([[target_rgb]]), cv2.COLOR_RGB2HSV)[0][0]
 
-        mask = langsam_results[0]["masks"][idx].astype(bool)
+        mask = prompt_segment_results["masks"][idx][0].astype(bool)
         h, s, v = cv2.split(img_hsv)
         # Convert all channels to float32 for consistent processing
         h = h.astype(np.float32)
@@ -168,9 +307,9 @@
         scale_s = target_s / mean_s if mean_s > 0 else 1.0
         scale_v = target_v / mean_v if mean_v > 0 else 1.0
 
-        scale_s = np.clip(scale_s, 0.8, 1.2)
+        scale_s = np.clip(scale_s, 0.8, 1.2)
         scale_v = np.clip(scale_v, 0.8, 1.2)
-
+
         # Apply changes only in mask
         h[mask] = target_hue
         s = s.astype(np.float32)
@@ -224,18 +363,16 @@ def change_image_objects_lab(
 
     print("change image objects lab targets config:", targets_config)
 
-    prompts = ". ".join(target[0] for target in targets_config)
+    prompts = [target[0].strip() for target in targets_config]
 
-    langsam_results = lang_sam_segment.remote(
+    prompt_segment_results = prompt_segment.remote(
         image_pil=image_pil,
-        prompt=prompts,
+        prompts=prompts,
     )
-    if not langsam_results:
+    if not prompt_segment_results:
         return image_pil
 
-    input_labels = [target[0] for target in targets_config]
-    output_labels = langsam_results[0]["labels"]
-    scores = langsam_results[0]["scores"]
+    output_labels = prompt_segment_results["labels"]
 
     img_array = np.array(image_pil)
     img_lab = cv2.cvtColor(img_array, cv2.COLOR_RGB2Lab).astype(np.float32)
@@ -244,13 +381,17 @@
         if not label or label == "":
             print("Skipping empty label.")
             continue
-        input_label, score, _ = process.extractOne(label, input_labels)
-        input_label_idx = input_labels.index(input_label)
+
+        if label not in prompts:
+            print(f"Label '{label}' not found in prompts. Skipping.")
+            continue
+
+        input_label_idx = prompts.index(label)
 
         new_a = targets_config[input_label_idx][1]
         new_b = targets_config[input_label_idx][2]
 
-        mask = langsam_results[0]["masks"][idx]
+        mask = prompt_segment_results["masks"][idx][0]
         mask_bool = mask.astype(bool)
 
         img_lab[mask_bool, 1] = new_a
@@ -298,49 +439,46 @@ def apply_mosaic_with_bool_mask(
 )
 def preserve_privacy(
     image_pil: Image.Image,
-    prompt: str,
+    prompts: str,
     privacy_strength: int = 15,
 ) -> Image.Image:
     """
     Preserves privacy in an image by applying a mosaic effect to specified objects.
     """
-    print(f"Preserving privacy for prompt: {prompt} with strength {privacy_strength}")
-
-    langsam_results = lang_sam_segment.remote(
+    print(f"Preserving privacy for prompt: {prompts} with strength {privacy_strength}")
+    if isinstance(prompts, str):
+        prompts = [prompt.strip() for prompt in prompts.split(".")]
+        print(f"Parsed prompts: {prompts}")
+    prompt_segment_results = prompt_segment.remote(
         image_pil=image_pil,
-        prompt=prompt,
-        box_threshold=0.35,
-        text_threshold=0.40,
+        prompts=prompts,
     )
-    if not langsam_results:
+    if not prompt_segment_results:
         return image_pil
 
     img_array = np.array(image_pil)
 
-    for result in langsam_results:
-        print(f"result: {result}")
-
-        for i, mask in enumerate(result["masks"]):
-            if "mask_scores" in result:
-                if (
-                    hasattr(result["mask_scores"], "shape")
-                    and result["mask_scores"].ndim > 0
-                ):
-                    mask_score = result["mask_scores"][i]
-                else:
-                    mask_score = result["mask_scores"]
-                if mask_score < 0.6:
-                    print(f"Skipping mask {i + 1}/{len(result['masks'])} -> low score.")
-                    continue
-            print(
-                f"Processing mask {i + 1}/{len(result['masks'])} Mask score: {mask_score}"  # noqa: E501
-            )
-
-            mask_bool = mask.astype(bool)
-
-            img_array = apply_mosaic_with_bool_mask.remote(
-                img_array, mask_bool, privacy_strength
-            )
+    for i, mask in enumerate(prompt_segment_results["masks"]):
+        mask_bool = mask[0].astype(bool)
+
+        # Create kernel for morphological operations
+        kernel_size = 100
+        kernel = np.ones((kernel_size, kernel_size), np.uint8)
+
+        # Convert bool mask to uint8 for OpenCV operations
+        mask_uint8 = mask_bool.astype(np.uint8) * 255
+
+        # Apply dilation to slightly expand the mask area
+        mask_uint8 = cv2.dilate(mask_uint8, kernel, iterations=2)
+        # Optional: Apply erosion again to refine the mask
+        mask_uint8 = cv2.erode(mask_uint8, kernel, iterations=2)
+
+        # Convert back to boolean mask
+        mask_bool = mask_uint8 > 127
+
+        img_array = apply_mosaic_with_bool_mask.remote(
+            img_array, mask_bool, privacy_strength
+        )
 
     output_image_pil = Image.fromarray(img_array)
 
@@ -354,14 +492,14 @@ def preserve_privacy(
     timeout=60 * 2,
 )
 def remove_background(image_pil: Image.Image) -> Image.Image:
-    from ben2 import BEN_Base
-    import torch
+    import torch  # type: ignore
+    from ben2 import BEN_Base  # type: ignore
 
     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
     print(f"Using device: {device}")
     print("type of image_pil:", type(image_pil))
     model = BEN_Base.from_pretrained("PramaLLC/BEN2")
-    model.to(device).eval()
+    model.to(device).eval()  # todo check if this should be outside the function
 
     output_image = model.inference(
         image_pil,
src/tools.py CHANGED
@@ -23,7 +23,7 @@ def remove_background(
     if not input_img:
         raise gr.Error("Input image cannot be None or empty.")
 
-    func = modal.Function.from_name("ImageAlfred", "remove_background")
+    func = modal.Function.from_name(modal_app_name, "remove_background")
     output_pil = func.remote(
         image_pil=input_img,
     )
@@ -67,10 +67,10 @@ def privacy_preserve_image(
     if not valid_pattern.match(input_prompt):
         raise gr.Error("Input prompt must contain only letters, spaces, and dots.")
 
-    func = modal.Function.from_name("ImageAlfred", "preserve_privacy")
+    func = modal.Function.from_name(modal_app_name, "preserve_privacy")
     output_pil = func.remote(
         image_pil=input_img,
-        prompt=input_prompt,
+        prompts=input_prompt,
         privacy_strength=privacy_strength,
     )
 
@@ -136,14 +136,14 @@ def change_color_objects_hsv(
         raise gr.Error("Red must be an integer.")
     if item[1] < 0 or item[1] > 255:
         raise gr.Error("Red must be in the range [0, 255]")
-
+
     try:
         item[2] = int(item[2])
     except ValueError:
         raise gr.Error("Green must be an integer.")
     if item[2] < 0 or item[2] > 255:
         raise gr.Error("Green must be in the range [0, 255]")
-
+
     try:
         item[3] = int(item[3])
     except ValueError:
@@ -153,7 +153,7 @@
 
     print("after processing input:", user_input)
 
-    func = modal.Function.from_name("ImageAlfred", "change_image_objects_hsv")
+    func = modal.Function.from_name(modal_app_name, "change_image_objects_hsv")
     output_pil = func.remote(image_pil=input_img, targets_config=user_input)
 
     if output_pil is None:
@@ -248,7 +248,7 @@ def change_color_objects_lab(
         raise gr.Error("new B must be in the range [0, 255]")
 
     print("after processing input:", user_input)
-    func = modal.Function.from_name("ImageAlfred", "change_image_objects_lab")
+    func = modal.Function.from_name(modal_app_name, "change_image_objects_lab")
     output_pil = func.remote(image_pil=input_img, targets_config=user_input)
     if output_pil is None:
         raise ValueError("Received None from modal remote function.")
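Side note on the src/tools.py changes: the hard-coded "ImageAlfred" app name is replaced by a modal_app_name variable whose definition is not part of this diff. A minimal sketch of the lookup pattern, assuming it is a module-level constant in src/tools.py:

import modal

modal_app_name = "ImageAlfred"  # assumption: defined once at module level; its definition is outside this diff

# Look up a function on the deployed Modal app by name, then call it remotely.
func = modal.Function.from_name(modal_app_name, "remove_background")
# output_pil = func.remote(image_pil=input_img)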