sagar007 committed on
Commit 23fa119 · verified · 1 Parent(s): eba2946

Update app.py

Files changed (1):
  app.py: +212 -130
app.py CHANGED
@@ -1,6 +1,6 @@
 import gradio as gr
 import torch
-from transformers import AutoProcessor, AutoModel
+from transformers import AutoProcessor, AutoModel # Keep CLIP for potential future use or if FastSAM's text prompt isn't enough
 from PIL import Image, ImageDraw, ImageFont
 import numpy as np
 import random
@@ -12,9 +12,11 @@ import traceback # For detailed error printing
 
 # Device Selection
 DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+# Force CPU if CUDA fails or isn't desired (sometimes needed on Spaces free tier)
+# DEVICE = "cpu"
 print(f"Using device: {DEVICE}")
 
-# --- CLIP Setup ---
+# --- CLIP Setup (Kept in case needed, but FastSAM's method is primary now) ---
 CLIP_MODEL_ID = "openai/clip-vit-base-patch32"
 clip_processor = None
 clip_model = None
@@ -22,17 +24,26 @@ clip_model = None
 def load_clip_model():
     global clip_processor, clip_model
     if clip_processor is None:
-        print(f"Loading CLIP processor: {CLIP_MODEL_ID}...")
-        clip_processor = AutoProcessor.from_pretrained(CLIP_MODEL_ID)
-        print("CLIP processor loaded.")
+        try:
+            print(f"Loading CLIP processor: {CLIP_MODEL_ID}...")
+            clip_processor = AutoProcessor.from_pretrained(CLIP_MODEL_ID)
+            print("CLIP processor loaded.")
+        except Exception as e:
+            print(f"Error loading CLIP processor: {e}")
+            return False # Indicate failure
     if clip_model is None:
-        print(f"Loading CLIP model: {CLIP_MODEL_ID}...")
-        clip_model = AutoModel.from_pretrained(CLIP_MODEL_ID).to(DEVICE)
-        print(f"CLIP model loaded to {DEVICE}.")
+        try:
+            print(f"Loading CLIP model: {CLIP_MODEL_ID}...")
+            clip_model = AutoModel.from_pretrained(CLIP_MODEL_ID).to(DEVICE)
+            print(f"CLIP model loaded to {DEVICE}.")
+        except Exception as e:
+            print(f"Error loading CLIP model: {e}")
+            return False # Indicate failure
+    return True # Indicate success
+
 
 # --- FastSAM Setup ---
 FASTSAM_CHECKPOINT = "FastSAM-s.pt"
-# Use the official model hub repo URL
 FASTSAM_CHECKPOINT_URL = f"https://huggingface.co/CASIA-IVA-Lab/FastSAM-s/resolve/main/{FASTSAM_CHECKPOINT}"
 
 fastsam_model = None
@@ -53,6 +64,7 @@ def check_and_import_fastsam():
         fastsam_lib_imported = False
     except Exception as e:
         print(f"An unexpected error occurred during fastsam import: {e}")
+        traceback.print_exc()
         fastsam_lib_imported = False
     return fastsam_lib_imported
 
@@ -66,168 +78,210 @@ def download_fastsam_weights():
     except Exception as e:
         print(f"Error downloading FastSAM weights: {e}")
         print("Please ensure the URL is correct and reachable, or manually place the weights file.")
-        # Attempt to remove partially downloaded file if exists
         if os.path.exists(FASTSAM_CHECKPOINT):
-            try:
-                os.remove(FASTSAM_CHECKPOINT)
-            except OSError:
-                pass # Ignore removal errors
+            try: os.remove(FASTSAM_CHECKPOINT)
+            except OSError: pass
         return False
     return os.path.exists(FASTSAM_CHECKPOINT)
 
 def load_fastsam_model():
     global fastsam_model
     if fastsam_model is None:
-        if not check_and_import_fastsam(): # Check import first
+        if not check_and_import_fastsam():
             print("Cannot load FastSAM model because the library couldn't be imported.")
-            return # Exit if import failed
+            return False # Indicate failure
 
-        if download_fastsam_weights(): # Check download/existence second
+        if download_fastsam_weights():
             try:
-                # FastSAM class should be available via globals() now
                 print(f"Loading FastSAM model: {FASTSAM_CHECKPOINT}...")
                 fastsam_model = FastSAM(FASTSAM_CHECKPOINT)
-                print(f"FastSAM model loaded.") # Device handled internally by FastSAM
+                # The FastSAM model itself doesn't need explicit .to(DEVICE)
+                # It seems to handle device selection internally or via the prompt process
+                print(f"FastSAM model loaded.")
+                return True # Indicate success
             except Exception as e:
                 print(f"Error loading FastSAM model: {e}")
                 traceback.print_exc()
         else:
             print("FastSAM weights not found or download failed. Cannot load model.")
+    return fastsam_model is not None # Return True if already loaded or loaded successfully
 
 
 # --- Processing Functions ---
 
+# (Keep run_clip_zero_shot and run_fastsam_segmentation as they were for the other tabs)
 # CLIP Zero-Shot Classification Function
 def run_clip_zero_shot(image: Image.Image, text_labels: str):
+    # Load CLIP if needed
     if clip_model is None or clip_processor is None:
-        load_clip_model() # Attempt to load if not already loaded
-        if clip_model is None:
-            return "Error: CLIP Model not loaded. Check logs.", None
+        if not load_clip_model():
+            return "Error: CLIP Model could not be loaded. Check logs.", None
 
-    if image is None:
-        return "Please upload an image.", None # Return None for the image display
-    if not text_labels:
-        # Return empty results but display the uploaded image
-        return {}, image
+    if image is None: return "Please upload an image.", None
+    if not text_labels: return {}, image # Return empty dict, show image
 
-    labels = [label.strip() for label in text_labels.split(',') if label.strip()] # Ensure non-empty labels
-    if not labels:
-        # Return empty results but display the uploaded image
-        return {}, image
+    labels = [label.strip() for label in text_labels.split(',') if label.strip()]
+    if not labels: return {}, image
 
     print(f"Running CLIP zero-shot classification with labels: {labels}")
-
     try:
-        # Ensure image is RGB
-        if image.mode != "RGB":
-            image = image.convert("RGB")
-
+        if image.mode != "RGB": image = image.convert("RGB")
         inputs = clip_processor(text=labels, images=image, return_tensors="pt", padding=True).to(DEVICE)
-
         with torch.no_grad():
             outputs = clip_model(**inputs)
-            logits_per_image = outputs.logits_per_image
-            probs = logits_per_image.softmax(dim=1)
-
+            probs = outputs.logits_per_image.softmax(dim=1)
         print("CLIP processing complete.")
-
         confidences = {labels[i]: float(probs[0, i].item()) for i in range(len(labels))}
-        # Return results and the original image used for prediction
         return confidences, image
-
     except Exception as e:
         print(f"Error during CLIP processing: {e}")
         traceback.print_exc()
-        # Return error message and the original image
         return f"An error occurred during CLIP: {e}", image
 
-
-# FastSAM Segmentation Function
+# FastSAM Everything Segmentation Function (for the second tab)
 def run_fastsam_segmentation(image_pil: Image.Image, conf_threshold: float = 0.4, iou_threshold: float = 0.9):
-    # Ensure model is loaded or attempt to load
-    if fastsam_model is None:
-        load_fastsam_model()
-        if fastsam_model is None:
-            # Return error message string for the image component (Gradio handles this)
-            return "Error: FastSAM Model not loaded. Check logs."
-    # Ensure library was imported
+    if not load_fastsam_model():
+        return "Error: FastSAM Model not loaded. Check logs."
     if not fastsam_lib_imported:
-        return "Error: FastSAM library not available. Cannot run segmentation."
+        return "Error: FastSAM library not available."
+    if image_pil is None: return "Please upload an image."
+
+    print("Running FastSAM 'segment everything'...")
+    try:
+        if image_pil.mode != "RGB": image_pil = image_pil.convert("RGB")
+        image_np_rgb = np.array(image_pil)
 
+        everything_results = fastsam_model(
+            image_np_rgb, device=DEVICE, retina_masks=True, imgsz=640,
+            conf=conf_threshold, iou=iou_threshold,
+        )
+        prompt_process = FastSAMPrompt(image_np_rgb, everything_results, device=DEVICE)
+        ann = prompt_process.everything_prompt()
+        print(f"FastSAM 'everything' found {len(ann[0]['masks']) if ann and ann[0] and 'masks' in ann[0] else 0} masks.")
+
+        # Plotting
+        output_image = image_pil.copy()
+        if ann and ann[0] is not None and 'masks' in ann[0] and len(ann[0]['masks']) > 0:
+            masks = ann[0]['masks'].cpu().numpy()
+            overlay = Image.new('RGBA', output_image.size, (0, 0, 0, 0))
+            draw = ImageDraw.Draw(overlay)
+            for mask in masks:
+                color = (random.randint(50, 255), random.randint(50, 255), random.randint(50, 255), 128)
+                mask_image = Image.fromarray((mask * 255).astype(np.uint8), mode='L')
+                draw.bitmap((0, 0), mask_image, fill=color)
+            output_image = Image.alpha_composite(output_image.convert('RGBA'), overlay).convert('RGB')
+
+        print("FastSAM 'everything' processing complete.")
+        return output_image
+
+    except Exception as e:
+        print(f"Error during FastSAM 'everything' processing: {e}")
+        traceback.print_exc()
+        return f"An error occurred during FastSAM 'everything': {e}"
+
+
+# --- NEW: Text-Prompted Segmentation Function ---
+def run_text_prompted_segmentation(image_pil: Image.Image, text_prompts: str, conf_threshold: float = 0.4, iou_threshold: float = 0.9):
+    """Segments objects based on text prompts."""
+    if not load_fastsam_model():
+        return "Error: FastSAM Model not loaded. Check logs.", "No prompts provided."
+    if not fastsam_lib_imported:
+        return "Error: FastSAM library not available.", "FastSAM library error."
     if image_pil is None:
-        return "Please upload an image."
+        return "Please upload an image.", "No image provided."
+    if not text_prompts:
+        return image_pil, "Please enter text prompts (e.g., 'person, dog')." # Return original image and message
 
-    print("Running FastSAM segmentation...")
+    prompts = [p.strip() for p in text_prompts.split(',') if p.strip()]
+    if not prompts:
+        return image_pil, "No valid text prompts entered."
+
+    print(f"Running FastSAM text-prompted segmentation for: {prompts}")
 
     try:
-        # Ensure image is RGB
         if image_pil.mode != "RGB":
             image_pil = image_pil.convert("RGB")
-
         image_np_rgb = np.array(image_pil)
 
-        # Run FastSAM inference
+        # 1. Run FastSAM once to get all potential results
+        # NOTE: We might optimize later, but this is the standard way FastSAMPrompt works.
         everything_results = fastsam_model(
-            image_np_rgb,
-            device=DEVICE,
-            retina_masks=True,
-            imgsz=640,
-            conf=conf_threshold,
-            iou=iou_threshold,
+            image_np_rgb, device=DEVICE, retina_masks=True, imgsz=640,
+            conf=conf_threshold, iou=iou_threshold, verbose=False # Less console spam
         )
 
-        # FastSAMPrompt should be available via globals() if import succeeded
+        # 2. Create the prompt processor
         prompt_process = FastSAMPrompt(image_np_rgb, everything_results, device=DEVICE)
-        ann = prompt_process.everything_prompt()
 
-        print(f"FastSAM found {len(ann[0]['masks']) if ann and ann[0] and 'masks' in ann[0] else 0} masks.")
-
-        # --- Plotting Masks on Image ---
+        # 3. Use text_prompt for each prompt and collect masks
+        all_matching_masks = []
+        found_prompts = []
+
+        for text in prompts:
+            print(f" Processing prompt: '{text}'")
+            # Ann is a list of dictionaries, one per image. We have one image.
+            # Each dict can have 'masks', 'bboxes', 'points'.
+            # text_prompt filters 'everything_results' based on CLIP-like similarity.
+            # It might return multiple masks if multiple instances match the text.
+            ann = prompt_process.text_prompt(text=text)
+
+            if ann and ann[0] is not None and 'masks' in ann[0] and len(ann[0]['masks']) > 0:
+                num_found = len(ann[0]['masks'])
+                print(f" Found {num_found} mask(s) matching '{text}'.")
+                found_prompts.append(f"{text} ({num_found})")
+                masks = ann[0]['masks'].cpu().numpy() # Get masks as numpy array (N, H, W)
+                all_matching_masks.extend(masks) # Add the numpy arrays to the list
+            else:
+                print(f" No masks found matching '{text}'.")
+                found_prompts.append(f"{text} (0)")
+
+        # 4. Plot the collected masks
         output_image = image_pil.copy()
-        if ann and ann[0] is not None and 'masks' in ann[0] and len(ann[0]['masks']) > 0:
-            masks = ann[0]['masks'].cpu().numpy() # (N, H, W) boolean
+        status_message = f"Found segments for: {', '.join(found_prompts)}" if found_prompts else "No matching segments found for any prompt."
 
-            overlay = Image.new('RGBA', output_image.size, (0, 0, 0, 0))
-            draw = ImageDraw.Draw(overlay)
+        if not all_matching_masks:
+            print("No matching masks found for any prompt.")
+            return output_image, status_message # Return original image if nothing matched
 
-            for i in range(masks.shape[0]):
-                mask = masks[i]
-                color = (random.randint(50, 255), random.randint(50, 255), random.randint(50, 255), 128) # RGBA
-                mask_image = Image.fromarray((mask * 255).astype(np.uint8), mode='L')
-                draw.bitmap((0,0), mask_image, fill=color)
+        # Convert list of (H, W) masks to a single (N, H, W) array for consistent processing
+        masks_np = np.stack(all_matching_masks, axis=0) # Shape (TotalMasks, H, W)
 
-            output_image = Image.alpha_composite(output_image.convert('RGBA'), overlay).convert('RGB')
+        overlay = Image.new('RGBA', output_image.size, (0, 0, 0, 0))
+        draw = ImageDraw.Draw(overlay)
 
-        print("FastSAM processing and plotting complete.")
-        # *** FIX: Return ONLY the output image for the single Image component ***
-        return output_image
+        for i in range(masks_np.shape[0]):
+            mask = masks_np[i] # Shape (H, W), boolean
+            color = (random.randint(50, 255), random.randint(50, 255), random.randint(50, 255), 150) # RGBA with slightly more alpha
+            mask_image = Image.fromarray((mask * 255).astype(np.uint8), mode='L')
+            draw.bitmap((0, 0), mask_image, fill=color)
+
+        output_image = Image.alpha_composite(output_image.convert('RGBA'), overlay).convert('RGB')
+
+        print("FastSAM text-prompted processing complete.")
+        return output_image, status_message
 
-    except NameError as ne:
-        print(f"NameError during FastSAM processing: {ne}. Was the fastsam library imported correctly?")
-        traceback.print_exc()
-        return f"A NameError occurred: {ne}. Check library import."
     except Exception as e:
-        print(f"Error during FastSAM processing: {e}")
+        print(f"Error during FastSAM text-prompted processing: {e}")
         traceback.print_exc()
-        return f"An error occurred during FastSAM: {e}"
+        return f"An error occurred: {e}", "Error during processing."
 
 
 # --- Gradio Interface ---
 
-# Pre-load models on startup (optional but good for performance)
 print("Attempting to preload models...")
-load_clip_model()
-load_fastsam_model() # This will now also attempt download/check import
+# load_clip_model() # Load CLIP lazily if needed
+load_fastsam_model() # Load FastSAM eagerly
 print("Preloading finished (or attempted).")
 
 
 with gr.Blocks(theme=gr.themes.Soft()) as demo:
     gr.Markdown("# CLIP & FastSAM Demo")
-    gr.Markdown("Explore Zero-Shot Classification with CLIP and 'Segment Anything' with FastSAM.")
+    gr.Markdown("Explore Zero-Shot Classification, 'Segment Everything', and Text-Prompted Segmentation.")
 
     with gr.Tabs():
-        # --- CLIP Tab ---
+        # --- CLIP Tab (No changes) ---
         with gr.TabItem("CLIP Zero-Shot Classification"):
+            # ... (keep the existing layout and logic for CLIP) ...
            gr.Markdown("Upload an image and provide comma-separated candidate labels (e.g., 'cat, dog, car'). CLIP will predict the probability of the image matching each label.")
            with gr.Row():
                with gr.Column(scale=1):
@@ -237,7 +291,6 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
                 with gr.Column(scale=1):
                     clip_output_label = gr.Label(label="Classification Probabilities")
                     clip_output_image_display = gr.Image(type="pil", label="Input Image Preview")
-
             clip_button.click(
                 run_clip_zero_shot,
                 inputs=[clip_input_image, clip_text_labels],
@@ -247,56 +300,88 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
                 examples=[
                     ["examples/astronaut.jpg", "astronaut, moon, rover, mountain"],
                     ["examples/dog_bike.jpg", "dog, bicycle, person, park, grass"],
-                    ["examples/clip_logo.png", "logo, text, graphics, abstract art"], # Added another example
+                    ["examples/clip_logo.png", "logo, text, graphics, abstract art"],
                 ],
                 inputs=[clip_input_image, clip_text_labels],
-                outputs=[clip_output_label, clip_output_image_display],
-                fn=run_clip_zero_shot,
-                cache_examples=False,
+                outputs=[clip_output_label, clip_output_image_display], fn=run_clip_zero_shot, cache_examples=False,
             )
 
-        # --- FastSAM Tab ---
-        with gr.TabItem("FastSAM Segmentation"):
-            gr.Markdown("Upload an image. FastSAM will attempt to segment all objects/regions in the image.")
+
+        # --- FastSAM Everything Tab (No changes) ---
+        with gr.TabItem("FastSAM Segment Everything"):
+            # ... (keep the existing layout and logic for segment everything) ...
+            gr.Markdown("Upload an image. FastSAM will attempt to segment all objects/regions in the image.")
+            with gr.Row():
+                with gr.Column(scale=1):
+                    fastsam_input_image_all = gr.Image(type="pil", label="Input Image", elem_id="fastsam_input_all") # Unique elem_id if needed
+                    with gr.Row():
+                        fastsam_conf_all = gr.Slider(minimum=0.1, maximum=1.0, value=0.4, step=0.05, label="Confidence Threshold")
+                        fastsam_iou_all = gr.Slider(minimum=0.1, maximum=1.0, value=0.9, step=0.05, label="IoU Threshold")
+                    fastsam_button_all = gr.Button("Run FastSAM Segmentation", variant="primary")
+                with gr.Column(scale=1):
+                    fastsam_output_image_all = gr.Image(type="pil", label="Segmented Image", elem_id="fastsam_output_all")
+            fastsam_button_all.click(
+                run_fastsam_segmentation,
+                inputs=[fastsam_input_image_all, fastsam_conf_all, fastsam_iou_all],
+                outputs=[fastsam_output_image_all]
+            )
+            gr.Examples(
+                examples=[
+                    ["examples/dogs.jpg", 0.4, 0.9],
+                    ["examples/fruits.jpg", 0.5, 0.8],
+                    ["examples/lion.jpg", 0.45, 0.9],
+                ],
+                inputs=[fastsam_input_image_all, fastsam_conf_all, fastsam_iou_all],
+                outputs=[fastsam_output_image_all], fn=run_fastsam_segmentation, cache_examples=False,
+            )
+
+        # --- NEW: Text-Prompted Segmentation Tab ---
+        with gr.TabItem("Text-Prompted Segmentation"):
+            gr.Markdown("Upload an image and provide comma-separated text prompts (e.g., 'person, dog, backpack'). FastSAM + CLIP (internally) will segment only the objects matching the text.")
             with gr.Row():
                 with gr.Column(scale=1):
-                    fastsam_input_image = gr.Image(type="pil", label="Input Image")
-                    with gr.Row():
-                        fastsam_conf = gr.Slider(minimum=0.1, maximum=1.0, value=0.4, step=0.05, label="Confidence Threshold")
-                        fastsam_iou = gr.Slider(minimum=0.1, maximum=1.0, value=0.9, step=0.05, label="IoU Threshold")
-                    fastsam_button = gr.Button("Run FastSAM Segmentation", variant="primary")
+                    prompt_input_image = gr.Image(type="pil", label="Input Image")
+                    prompt_text_input = gr.Textbox(label="Comma-Separated Text Prompts", placeholder="e.g., glasses, watch, t-shirt")
+                    with gr.Row(): # Reuse confidence/IoU sliders if desired
+                        prompt_conf = gr.Slider(minimum=0.1, maximum=1.0, value=0.4, step=0.05, label="Confidence Threshold")
+                        prompt_iou = gr.Slider(minimum=0.1, maximum=1.0, value=0.9, step=0.05, label="IoU Threshold")
+                    prompt_button = gr.Button("Segment by Text", variant="primary")
                 with gr.Column(scale=1):
-                    fastsam_output_image = gr.Image(type="pil", label="Segmented Image")
+                    prompt_output_image = gr.Image(type="pil", label="Text-Prompted Segmentation")
+                    prompt_status_message = gr.Textbox(label="Status", interactive=False) # To show which prompts matched
 
-            fastsam_button.click(
-                run_fastsam_segmentation,
-                inputs=[fastsam_input_image, fastsam_conf, fastsam_iou],
-                # Output is now correctly mapped to the single component
-                outputs=[fastsam_output_image]
+            prompt_button.click(
+                run_text_prompted_segmentation,
+                inputs=[prompt_input_image, prompt_text_input, prompt_conf, prompt_iou],
+                outputs=[prompt_output_image, prompt_status_message] # Map to image and status box
             )
             gr.Examples(
                 examples=[
-                    ["examples/dogs.jpg", 0.4, 0.9],
-                    ["examples/fruits.jpg", 0.5, 0.8],
-                    ["examples/lion.jpg", 0.45, 0.9], # Added another example
+                    ["examples/dog_bike.jpg", "person, bicycle", 0.4, 0.9],
+                    ["examples/astronaut.jpg", "person, helmet", 0.35, 0.9],
+                    ["examples/dogs.jpg", "dog", 0.4, 0.9], # Should find multiple dogs
+                    ["examples/fruits.jpg", "banana, apple", 0.5, 0.8],
+                    ["examples/teacher.jpg", "person, glasses, blackboard", 0.4, 0.9], # Download this image or use another one with glasses/blackboard
                 ],
-                inputs=[fastsam_input_image, fastsam_conf, fastsam_iou],
-                outputs=[fastsam_output_image],
-                fn=run_fastsam_segmentation,
+                inputs=[prompt_input_image, prompt_text_input, prompt_conf, prompt_iou],
+                outputs=[prompt_output_image, prompt_status_message],
+                fn=run_text_prompted_segmentation,
                 cache_examples=False,
             )
 
-    # Add example images (optional, but helpful)
+    # Ensure example images exist or are downloaded
+    # (Keep the existing example download logic, maybe add teacher.jpg if used in examples)
     if not os.path.exists("examples"):
         os.makedirs("examples")
         print("Created 'examples' directory. Attempting to download sample images...")
         example_files = {
-            "astronaut.jpg": "https://upload.wikimedia.org/wikipedia/commons/thumb/d/d astronaut_-_St._Jean_Bay.jpg/640px-Astronaut_-_St._Jean_Bay.jpg", # Find suitable public domain/CC image
-            "dog_bike.jpg": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/gradio/outputs_multimodal.jpg", # Using a relevant example from HF
-            "clip_logo.png": "https://raw.githubusercontent.com/openai/CLIP/main/CLIP.png",
-            "dogs.jpg": "https://raw.githubusercontent.com/ultralytics/assets/main/im/image8.jpg", # From Ultralytics assets
-            "fruits.jpg": "https://raw.githubusercontent.com/ultralytics/assets/main/im/image9.jpg", # From Ultralytics assets
-            "lion.jpg": "https://huggingface.co/spaces/gradio/image-segmentation/resolve/main/images/lion.jpg"
+            "astronaut.jpg": "https://upload.wikimedia.org/wikipedia/commons/thumb/d/d1/Astronaut_-_St._Jean_Bay.jpg/640px-Astronaut_-_St._Jean_Bay.jpg",
+            "dog_bike.jpg": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/gradio/outputs_multimodal.jpg",
+            "clip_logo.png": "https://raw.githubusercontent.com/openai/CLIP/main/CLIP.png",
+            "dogs.jpg": "https://raw.githubusercontent.com/ultralytics/assets/main/im/image8.jpg",
+            "fruits.jpg": "https://raw.githubusercontent.com/ultralytics/assets/main/im/image9.jpg",
+            "lion.jpg": "https://huggingface.co/spaces/gradio/image-segmentation/resolve/main/images/lion.jpg",
+            "teacher.jpg": "https://images.pexels.com/photos/848117/pexels-photo-848117.jpeg?auto=compress&cs=tinysrgb&w=600" # Example with glasses/board
         }
         for filename, url in example_files.items():
             filepath = os.path.join("examples", filename)
@@ -311,7 +396,4 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
 
 # Launch the Gradio app
 if __name__ == "__main__":
-    # share=True is primarily for local testing to get a public link.
-    # Not needed/used when deploying on Hugging Face Spaces.
-    # debug=True is helpful for development. Set to False for production.
-    demo.launch(debug=True)
+    demo.launch(debug=True) # debug=True is helpful locally
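
A minimal sketch of exercising the new text-prompted path in this commit without the Gradio UI, assuming the updated file is saved as app.py, the fastsam package and FastSAM-s.pt weights are available, and examples/dog_bike.jpg has already been downloaded (the image path comes from the Examples list above; the quick_test.py name and the output filename are hypothetical):

# quick_test.py (hypothetical) -- smoke-test run_text_prompted_segmentation from a Python shell
from PIL import Image
import app  # importing app.py runs the module-level preloading and builds the Blocks UI, but does not launch it

image = Image.open("examples/dog_bike.jpg")

# Returns (overlaid PIL image, status string such as "Found segments for: person (N), bicycle (N)");
# on failure the first element is an error string instead of an image.
result, status = app.run_text_prompted_segmentation(image, "person, bicycle", conf_threshold=0.4, iou_threshold=0.9)
print(status)
if not isinstance(result, str):
    result.save("dog_bike_prompted.png")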