sagar007 committed
Commit 0747bb5 · verified · 1 Parent(s): 22401e9

Update app.py

Files changed (1):
  1. app.py +312 -200
app.py CHANGED
@@ -7,11 +7,12 @@ import random
 import os
 import wget
 import traceback
+import sys # Import sys for checking modules

 # --- Configuration & Model Loading ---

 # Device Selection with fallback
-DEVICE = "cuda" if torch.cuda.is_available() else "cpu" # Simplified check
+DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
 print(f"Using device: {DEVICE}")

 # --- CLIP Setup ---
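The device check in this hunk is the standard PyTorch idiom. As a minimal, self-contained sketch of how a DEVICE string like this is consumed downstream (the layer and tensor here are illustrative, not from this commit):

    import torch

    DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

    # Modules and tensors are moved with .to(DEVICE); both must be on the
    # same device before a forward pass.
    layer = torch.nn.Linear(4, 2).to(DEVICE)
    x = torch.randn(1, 4).to(DEVICE)
    print(layer(x).device)  # cuda:0 when a GPU is visible, otherwise cpu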
@@ -28,7 +29,7 @@ def load_clip_model():
         print("CLIP processor loaded.")
     except Exception as e:
         print(f"Error loading CLIP processor: {e}")
-        traceback.print_exc() # Print traceback
+        traceback.print_exc()
         return False
     if clip_model is None:
         try:
@@ -37,7 +38,7 @@ def load_clip_model():
         print(f"CLIP model loaded to {DEVICE}.")
     except Exception as e:
         print(f"Error loading CLIP model: {e}")
-        traceback.print_exc() # Print traceback
+        traceback.print_exc()
         return False
     return True

@@ -51,17 +52,37 @@ FastSAM = None # Define placeholders
 FastSAMPrompt = None # Define placeholders

 def check_and_import_fastsam():
-    global fastsam_lib_imported, FastSAM, FastSAMPrompt # Make sure globals are modified
+    global fastsam_lib_imported, FastSAM, FastSAMPrompt
     if not fastsam_lib_imported:
+        # Check if ultralytics is installed first, as it's a dependency
+        if 'ultralytics' not in sys.modules:
+            try:
+                # Try importing to trigger potential error if not installed
+                import ultralytics
+                print("Found 'ultralytics' library.")
+            except ImportError:
+                print("\n--- ERROR ---")
+                print("The 'ultralytics' library (required by FastSAM) is not installed.")
+                print("Please install it first: pip install ultralytics")
+                print("---------------\n")
+                return False # Cannot proceed without ultralytics
+
+        # Now try importing fastsam
         try:
-            from fastsam import FastSAM as FastSAM_lib, FastSAMPrompt as FastSAMPrompt_lib # Use temp names
-            FastSAM = FastSAM_lib # Assign to global
-            FastSAMPrompt = FastSAMPrompt_lib # Assign to global
+            # Use temporary names to avoid conflict if they exist globally somehow
+            from fastsam import FastSAM as FastSAM_lib, FastSAMPrompt as FastSAMPrompt_lib
+            FastSAM = FastSAM_lib # Assign to global placeholder
+            FastSAMPrompt = FastSAMPrompt_lib # Assign to global placeholder
             fastsam_lib_imported = True
             print("fastsam library imported successfully.")
         except ImportError as e:
-            print(f"Error: 'fastsam' library not found. Please install it: pip install git+https://github.com/CASIA-IVA-Lab/FastSAM.git")
-            print(f"ImportError: {e}")
+            print("\n--- ERROR ---")
+            print("The 'fastsam' library was not found or could not be imported.")
+            print("Please ensure it is installed correctly:")
+            print(" pip install git+https://github.com/CASIA-IVA-Lab/FastSAM.git")
+            print(f"(ImportError: {e})")
+            print("Also ensure 'ultralytics' is installed: pip install ultralytics")
+            print("---------------\n")
             fastsam_lib_imported = False
         except Exception as e:
             print(f"Unexpected error during fastsam import: {e}")
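The rewritten import helper is a lazy-import guard: probe sys.modules, import once, and cache the result in module-level globals so later calls are cheap. Note that the sys.modules test only detects an already-imported module; the try/import that follows is what actually verifies installation. The same pattern in isolation (the package name somepkg is a hypothetical stand-in for fastsam):

    import sys

    somepkg_imported = False
    SomeClass = None  # module-level placeholder, filled on first successful import

    def check_and_import_somepkg():
        global somepkg_imported, SomeClass
        if not somepkg_imported:
            try:
                from somepkg import SomeClass as SomeClass_lib  # hypothetical package
                SomeClass = SomeClass_lib
                somepkg_imported = True
            except ImportError as e:
                print(f"'somepkg' is not installed: {e}")
        return somepkg_imported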
@@ -72,12 +93,20 @@ def check_and_import_fastsam():
 def download_fastsam_weights(retries=3):
     if not os.path.exists(FASTSAM_CHECKPOINT):
         print(f"Downloading FastSAM weights: {FASTSAM_CHECKPOINT} from {FASTSAM_CHECKPOINT_URL}...")
+        # Ensure the directory exists if FASTSAM_CHECKPOINT includes a path
+        checkpoint_dir = os.path.dirname(FASTSAM_CHECKPOINT)
+        if checkpoint_dir and not os.path.exists(checkpoint_dir):
+            try:
+                os.makedirs(checkpoint_dir)
+                print(f"Created directory for weights: {checkpoint_dir}")
+            except OSError as e:
+                print(f"Error creating directory {checkpoint_dir}: {e}")
+                return False
+
         for attempt in range(retries):
             try:
-                # Ensure the directory exists if FASTSAM_CHECKPOINT includes a path
-                os.makedirs(os.path.dirname(FASTSAM_CHECKPOINT) or '.', exist_ok=True)
                 wget.download(FASTSAM_CHECKPOINT_URL, FASTSAM_CHECKPOINT)
-                print("FastSAM weights downloaded.")
+                print("FastSAM weights downloaded successfully.")
                 return True # Return True on successful download
             except Exception as e:
                 print(f"Attempt {attempt + 1}/{retries} failed to download FastSAM weights: {e}")
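The directory handling swapped in here is behaviorally close to the removed one-liner; os.makedirs with exist_ok=True already tolerates a pre-existing directory, so the explicit os.path.exists check mainly buys a clearer log message and an early False return. A minimal equivalent (the checkpoint path is illustrative, not the value from this app):

    import os

    checkpoint_path = "weights/FastSAM.pt"  # illustrative path
    checkpoint_dir = os.path.dirname(checkpoint_path)
    if checkpoint_dir:
        # No-op when the directory already exists
        os.makedirs(checkpoint_dir, exist_ok=True)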
@@ -89,48 +118,53 @@ def download_fastsam_weights(retries=3):
                 if attempt + 1 == retries:
                     print("Failed to download weights after all attempts.")
                     return False
-        return False # Should not be reached if loop completes, but added for clarity
+        return False # Should not be reached if loop completes correctly
     else:
-        print("FastSAM weights already exist.")
+        print(f"FastSAM weights file '{FASTSAM_CHECKPOINT}' already exists.")
         return True # Weights exist

 def load_fastsam_model():
     global fastsam_model
     if fastsam_model is None:
+        print("Attempting to load FastSAM model...")
         if not check_and_import_fastsam():
             print("Cannot load FastSAM model due to library import failure.")
             return False
-        if download_fastsam_weights():
-            # Ensure FastSAM class is available (might fail if import failed earlier but file exists)
-            if FastSAM is None:
-                print("FastSAM class not available, check import status.")
-                return False
-            try:
-                print(f"Loading FastSAM model: {FASTSAM_CHECKPOINT}...")
-                # Instantiate the imported class
-                fastsam_model = FastSAM(FASTSAM_CHECKPOINT)
-                # Move model to device *after* initialization (common practice)
-                # Note: Check FastSAM docs if it needs explicit .to(DEVICE) or handles it internally
-                # fastsam_model.model.to(DEVICE) # Example if needed, adjust based on FastSAM structure
-                print("FastSAM model loaded.")
-                return True
-            except Exception as e:
-                print(f"Error loading FastSAM model weights or initializing: {e}")
-                traceback.print_exc()
-                return False
-        else:
-            print("FastSAM weights not found or download failed.")
+        if not download_fastsam_weights():
+            print("Cannot load FastSAM model because weights are missing or download failed.")
+            return False
+
+        # Ensure FastSAM class is available (double check after import attempt)
+        if FastSAM is None:
+            print("FastSAM class reference is None, cannot instantiate model.")
+            return False
+
+        try:
+            print(f"Loading FastSAM model from checkpoint: {FASTSAM_CHECKPOINT}...")
+            # Instantiate the imported FastSAM class
+            fastsam_model = FastSAM(FASTSAM_CHECKPOINT)
+            # Note: FastSAM typically handles device placement internally based on constructor args or method calls.
+            # If you face device issues, check FastSAM's documentation for explicit device moving.
+            # Example: Some models might need fastsam_model.model.to(DEVICE) - check structure.
+            print("FastSAM model loaded successfully.")
+            return True
+        except Exception as e:
+            print(f"Error loading FastSAM model weights or initializing: {e}")
+            traceback.print_exc()
+            fastsam_model = None # Ensure model is None if loading failed
             return False
     # Model already loaded
+    # print("FastSAM model already loaded.") # Optional: uncomment for debugging reuse
     return True

 # --- Processing Functions ---

 def run_clip_zero_shot(image: Image.Image, text_labels: str):
-    # Keep CLIP as is, seems less likely to be the primary issue
+    # Input validation
+    if image is None:
+        return "Error: Please upload an image.", None # Return None for image component
     if not isinstance(image, Image.Image):
-        print(f"CLIP input is not a PIL Image, type: {type(image)}")
-        # Try to convert if it's a numpy array (common from Gradio)
+        print(f"CLIP input is not a PIL Image, type: {type(image)}. Attempting conversion.")
         if isinstance(image, np.ndarray):
             try:
                 image = Image.fromarray(image)
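The rewritten load_fastsam_model() follows a load-once singleton shape: a module-level handle starts as None, each failure path returns False (and now resets the handle), and repeat calls return immediately. Reduced to its skeleton (expensive_load is a hypothetical stand-in for FastSAM(FASTSAM_CHECKPOINT)):

    _model = None  # module-level cache

    def load_model(path="weights.pt"):
        """Return True when the model is ready, loading it at most once."""
        global _model
        if _model is not None:
            return True  # already loaded, nothing to do
        try:
            _model = expensive_load(path)  # hypothetical loader
            return True
        except Exception as e:
            print(f"Loading failed: {e}")
            _model = None  # keep the cache empty so a later retry is possible
            return False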
@@ -141,18 +175,18 @@ def run_clip_zero_shot(image: Image.Image, text_labels: str):
     else:
         return "Error: Please provide a valid image.", None

+    # Model loading check
     if clip_model is None or clip_processor is None:
         if not load_clip_model():
-            # Return None for the image part on critical error
             return "Error: CLIP Model could not be loaded.", None
+
+    # Label check
     if not text_labels:
-        # Return empty dict and original image if no labels
-        return {}, image
+        return {}, image # Return empty dict and original image if no labels

     labels = [label.strip() for label in text_labels.split(',') if label.strip()]
     if not labels:
-        # Return empty dict and original image if no valid labels
-        return {}, image
+        return {}, image # Return empty dict and original image if no valid labels

     print(f"Running CLIP zero-shot classification with labels: {labels}")
     try:
@@ -164,46 +198,42 @@ def run_clip_zero_shot(image: Image.Image, text_labels: str):
         inputs = clip_processor(text=labels, images=image, return_tensors="pt", padding=True).to(DEVICE)
         with torch.no_grad():
             outputs = clip_model(**inputs)
-        # Calculate probabilities
-        logits_per_image = outputs.logits_per_image # B x N_labels
-        probs = logits_per_image.softmax(dim=1) # Softmax over labels
+        logits_per_image = outputs.logits_per_image
+        probs = logits_per_image.softmax(dim=1)

-        # Create confidences dictionary
         confidences = {labels[i]: float(probs[0, i].item()) for i in range(len(labels))}
         print(f"CLIP Confidences: {confidences}")
-        # Return confidences and the original (potentially converted) image
         return confidences, image
+
     except Exception as e:
         print(f"Error during CLIP processing: {e}")
         traceback.print_exc()
-        # Return error message and None for image
         return f"Error during CLIP processing: {e}", None


 def run_fastsam_segmentation(image_pil: Image.Image, conf_threshold: float = 0.4, iou_threshold: float = 0.9):
-    # Add input type check
+    # Input validation
+    if image_pil is None:
+        return None, "Error: Please upload an image."
     if not isinstance(image_pil, Image.Image):
-        print(f"FastSAM input is not a PIL Image, type: {type(image_pil)}")
+        print(f"FastSAM input is not a PIL Image, type: {type(image_pil)}. Attempting conversion.")
         if isinstance(image_pil, np.ndarray):
             try:
                 image_pil = Image.fromarray(image_pil)
                 print("Converted numpy input to PIL Image for FastSAM.")
             except Exception as e:
                 print(f"Failed to convert numpy array to PIL Image: {e}")
-                # Return None for image on error
-                return None, "Error: Invalid image input format." # Return tuple for consistency
+                return None, "Error: Invalid image input format."
         else:
-            # Return None for image on error
-            return None, "Error: Please provide a valid image." # Return tuple
+            return None, "Error: Please provide a valid image."

-    # Ensure model is loaded
+    # Model loading check
     if not load_fastsam_model() or not fastsam_lib_imported or FastSAMPrompt is None:
-        # Return None for image on critical error
-        return None, "Error: FastSAM not loaded or library unavailable."
+        return image_pil, "Error: FastSAM model/library not ready. Check logs." # Return original image if model failed

     print(f"Running FastSAM 'segment everything' with conf={conf_threshold}, iou={iou_threshold}...")
-    output_image = None # Initialize output image
-    status_message = "Processing..." # Initial status
+    output_image = None
+    status_message = "Processing..."

     try:
         # Ensure image is RGB
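For reference, the CLIP branch above matches the usual Hugging Face transformers recipe: logits_per_image has shape (batch, num_labels), and a softmax over dim=1 yields the per-label distribution. A standalone sketch (checkpoint name and labels are illustrative):

    import torch
    from PIL import Image
    from transformers import CLIPModel, CLIPProcessor

    model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
    processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

    image = Image.new("RGB", (224, 224), "white")  # placeholder image
    labels = ["astronaut", "moon"]
    inputs = processor(text=labels, images=image, return_tensors="pt", padding=True)
    with torch.no_grad():
        probs = model(**inputs).logits_per_image.softmax(dim=1)
    print(dict(zip(labels, probs[0].tolist())))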
@@ -213,42 +243,31 @@ def run_fastsam_segmentation(image_pil: Image.Image, conf_threshold: float = 0.4
         else:
             image_pil_rgb = image_pil

-        # Convert PIL Image to NumPy array (RGB)
         image_np_rgb = np.array(image_pil_rgb)
         print(f"Input image shape for FastSAM: {image_np_rgb.shape}")

         # Run FastSAM model
-        # Make sure the arguments match what FastSAM expects
         everything_results = fastsam_model(
-            image_np_rgb,
-            device=DEVICE,
-            retina_masks=True,
-            imgsz=640, # Or another size FastSAM supports
-            conf=conf_threshold,
-            iou=iou_threshold,
-            verbose=True # Keep verbose for debugging
+            image_np_rgb, device=DEVICE, retina_masks=True, imgsz=640, # Adjust imgsz if needed
+            conf=conf_threshold, iou=iou_threshold, verbose=False # Set verbose=False for cleaner logs unless debugging
         )

-        # Check if results are valid before creating prompt
+        # Check results type and content (FastSAM results format might vary)
+        # Typically a list of result objects, or similar structure
         if everything_results is None or not isinstance(everything_results, list) or len(everything_results) == 0:
-            print("FastSAM model returned None or empty results.")
-            # Return original image and status
-            return image_pil, "FastSAM did not return valid results."
-
-        # Results might be in a different format, inspect 'everything_results'
-        print(f"Type of everything_results: {type(everything_results)}")
-        print(f"Length of everything_results: {len(everything_results)}")
-        if len(everything_results) > 0:
-            print(f"Type of first element: {type(everything_results[0])}")
-            # Try to access potential attributes like 'masks' if it's an object
-            if hasattr(everything_results[0], 'masks') and everything_results[0].masks is not None:
-                print(f"Masks found in results object, shape: {everything_results[0].masks.data.shape}")
-            else:
-                print("First result element does not have 'masks' attribute or it's None.")
+            print("FastSAM model returned None or empty results list.")
+            return image_pil, "FastSAM processing returned no results."
+
+        # Assuming the first result object contains the relevant data
+        first_result = everything_results[0]

+        # --- IMPORTANT: Inspect the 'first_result' object ---
+        # Use print(dir(first_result)), print(type(first_result)) etc. if unsure
+        # Common attributes might be .masks, .boxes, .names
+        # print(f"Type of first_result: {type(first_result)}")
+        # print(f"Attributes of first_result: {dir(first_result)}")

-        # Process results with FastSAMPrompt
-        # Ensure FastSAMPrompt class is available
+        # Initialize FastSAMPrompt
         if FastSAMPrompt is None:
             print("FastSAMPrompt class is not available.")
             return image_pil, "Error: FastSAMPrompt class not loaded."
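The commented inspection hints in this hunk reflect that FastSAM's results follow the ultralytics-style Results object. If the FastSAMPrompt path ever misbehaves, the raw masks can usually be reached on the first result directly; a hedged fragment (the .masks.data layout of (num_masks, H, W) is assumed from the ultralytics convention, so verify against the installed version):

    # Assumes `everything_results` as returned by fastsam_model(...) above.
    first_result = everything_results[0]
    if getattr(first_result, "masks", None) is not None:
        mask_tensor = first_result.masks.data  # torch.Tensor, assumed (num_masks, H, W)
        print("mask tensor shape:", tuple(mask_tensor.shape))
    else:
        print("no masks on this result object")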
@@ -256,89 +275,83 @@ def run_fastsam_segmentation(image_pil: Image.Image, conf_threshold: float = 0.4
         prompt_process = FastSAMPrompt(image_np_rgb, everything_results, device=DEVICE)
         ann = prompt_process.everything_prompt() # Get all annotations

-        # Check annotation format - Adjust based on actual FastSAM output structure
-        # Assuming 'ann' is a list and the first element is a dictionary containing masks
+        # Check annotation format - Adapt based on actual FastSAM/FastSAMPrompt output
         masks = None
+        # Expected format: list containing a dict with 'masks' tensor
         if isinstance(ann, list) and len(ann) > 0 and isinstance(ann[0], dict) and 'masks' in ann[0]:
             mask_tensor = ann[0]['masks']
-            if mask_tensor is not None and mask_tensor.numel() > 0: # Check if tensor is not None and not empty
+            if mask_tensor is not None and isinstance(mask_tensor, torch.Tensor) and mask_tensor.numel() > 0:
                 masks = mask_tensor.cpu().numpy()
                 print(f"Found {len(masks)} masks with shape: {masks.shape}")
             else:
-                print("Annotation 'masks' tensor is None or empty.")
+                print("Annotation 'masks' tensor is None, not a Tensor, or empty.")
         else:
             print(f"No masks found or annotation format unexpected. ann type: {type(ann)}")
-            if isinstance(ann, list) and len(ann) > 0:
-                print(f"First element of ann: {ann[0]}")
-
+            if isinstance(ann, list) and len(ann) > 0: print(f"First element of ann: {ann[0]}")

-        # Prepare output image (start with original)
+        # Prepare output image
         output_image = image_pil.copy()

         # Draw masks if found
         if masks is not None and len(masks) > 0:
-            # Ensure output_image is RGBA for compositing
             overlay = Image.new('RGBA', output_image.size, (0, 0, 0, 0))
             draw = ImageDraw.Draw(overlay)
-
+            valid_masks_drawn = 0
             for i, mask in enumerate(masks):
-                # Ensure mask is boolean/binary before converting
-                binary_mask = (mask > 0) # Use threshold 0 for binary mask from FastSAM output
+                binary_mask = (mask > 0) # Use threshold 0 for binary mask
                 mask_uint8 = binary_mask.astype(np.uint8) * 255
-                if mask_uint8.max() == 0: # Skip empty masks
-                    # print(f"Skipping empty mask {i}")
-                    continue
+                if mask_uint8.max() == 0: continue # Skip empty masks

-                color = (random.randint(50, 255), random.randint(50, 255), random.randint(50, 255), 180) # RGBA color
+                color = (random.randint(50, 255), random.randint(50, 255), random.randint(50, 255), 180)
                 try:
-                    mask_image = Image.fromarray(mask_uint8, mode='L') # Grayscale mask
-                    # Draw the mask onto the overlay
+                    mask_image = Image.fromarray(mask_uint8, mode='L')
                     draw.bitmap((0, 0), mask_image, fill=color)
+                    valid_masks_drawn += 1
                 except Exception as draw_err:
                     print(f"Error drawing mask {i}: {draw_err}")
                     traceback.print_exc()
-                    continue # Skip this mask
-
-            # Composite the overlay onto the image
-            try:
-                output_image_rgba = output_image.convert('RGBA')
-                output_image_composited = Image.alpha_composite(output_image_rgba, overlay)
-                output_image = output_image_composited.convert('RGB') # Convert back to RGB for Gradio
-                status_message = f"Segmentation complete. Found {len(masks)} masks."
-                print("Mask drawing and compositing finished.")
-            except Exception as comp_err:
-                print(f"Error during alpha compositing: {comp_err}")
-                traceback.print_exc()
-                output_image = image_pil # Fallback to original image
-                status_message = "Error during mask visualization."

+            if valid_masks_drawn > 0:
+                try:
+                    output_image_rgba = output_image.convert('RGBA')
+                    output_image_composited = Image.alpha_composite(output_image_rgba, overlay)
+                    output_image = output_image_composited.convert('RGB')
+                    status_message = f"Segmentation complete. Found and drew {valid_masks_drawn} masks."
+                    print("Mask drawing and compositing finished.")
+                except Exception as comp_err:
+                    print(f"Error during alpha compositing: {comp_err}")
+                    traceback.print_exc()
+                    output_image = image_pil # Fallback
+                    status_message = f"Found {valid_masks_drawn} masks, but error during visualization."
+            else:
+                status_message = f"Found {len(masks)} masks initially, but none were valid for drawing."
+                output_image = image_pil # Return original if no valid masks drawn
         else:
             print("No masks detected or processed for 'segment everything' mode.")
             status_message = "No segments found or processed."
-            output_image = image_pil # Return original image if no masks
+            output_image = image_pil # Return original image

         # Save for debugging before returning
         if output_image:
             try:
-                debug_path = "debug_fastsam_everything_output.png"
-                output_image.save(debug_path)
-                print(f"Saved debug output to {debug_path}")
+                output_image.save("debug_fastsam_everything_output.png")
             except Exception as save_err:
                 print(f"Failed to save debug image: {save_err}")

-        return output_image, status_message # Return image and status message
+        return output_image, status_message

     except Exception as e:
         print(f"Error during FastSAM 'everything' processing: {e}")
         traceback.print_exc()
-        # Return original image and error message in case of failure
-        return image_pil, f"Error during processing: {e}"
+        return image_pil, f"Error during processing: {e}" # Return original image and error


 def run_text_prompted_segmentation(image_pil: Image.Image, text_prompts: str, conf_threshold: float = 0.4, iou_threshold: float = 0.9):
-    # Add input type check
+    # Input validation
+    if image_pil is None:
+        return None, "Error: Please upload an image."
     if not isinstance(image_pil, Image.Image):
-        print(f"FastSAM Text input is not a PIL Image, type: {type(image_pil)}")
+        print(f"FastSAM Text input is not a PIL Image, type: {type(image_pil)}. Attempting conversion.")
        if isinstance(image_pil, np.ndarray):
             try:
                 image_pil = Image.fromarray(image_pil)
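The overlay drawing used in both segmentation paths is plain Pillow. Isolated, the technique looks like this (the mask here is synthetic):

    import numpy as np
    from PIL import Image, ImageDraw

    base = Image.new("RGB", (64, 64), "black")
    mask = np.zeros((64, 64), dtype=np.uint8)
    mask[16:48, 16:48] = 255  # synthetic binary mask

    overlay = Image.new("RGBA", base.size, (0, 0, 0, 0))
    draw = ImageDraw.Draw(overlay)
    # draw.bitmap() paints `fill` wherever the grayscale mask is nonzero
    draw.bitmap((0, 0), Image.fromarray(mask, mode="L"), fill=(255, 0, 0, 180))

    # alpha_composite requires RGBA on both sides; convert back to RGB at the end
    blended = Image.alpha_composite(base.convert("RGBA"), overlay).convert("RGB")
    blended.save("overlay_demo.png")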
@@ -349,9 +362,9 @@ def run_text_prompted_segmentation(image_pil: Image.Image, text_prompts: str, co
     else:
         return None, "Error: Please provide a valid image."

-    # Ensure model is loaded
+    # Model loading check
     if not load_fastsam_model() or not fastsam_lib_imported or FastSAMPrompt is None:
-        return image_pil, "Error: FastSAM Model not loaded or library unavailable." # Return original image on load fail
+        return image_pil, "Error: FastSAM model/library not ready. Check logs."
     if not text_prompts:
         return image_pil, "Please enter text prompts (e.g., 'person, dog')."

@@ -376,14 +389,13 @@ def run_text_prompted_segmentation(image_pil: Image.Image, text_prompts: str, co

         # Run FastSAM once to get all potential segments
         everything_results = fastsam_model(
-            image_np_rgb, device=DEVICE, retina_masks=True, imgsz=640, # Use consistent args
-            conf=conf_threshold, iou=iou_threshold, verbose=True
+            image_np_rgb, device=DEVICE, retina_masks=True, imgsz=640,
+            conf=conf_threshold, iou=iou_threshold, verbose=False # Set verbose=False usually
         )

-        # Check results
         if everything_results is None or not isinstance(everything_results, list) or len(everything_results) == 0:
             print("FastSAM model returned None or empty results for text prompt base.")
-            return image_pil, "FastSAM did not return base results."
+            return image_pil, "FastSAM did not return base results needed for text prompting."

         # Initialize FastSAMPrompt
         if FastSAMPrompt is None:
@@ -392,33 +404,30 @@ def run_text_prompted_segmentation(image_pil: Image.Image, text_prompts: str, co
         prompt_process = FastSAMPrompt(image_np_rgb, everything_results, device=DEVICE)

         all_matching_masks = []
-        found_prompts_details = [] # Store details like 'prompt (count)'
+        found_prompts_details = []

         # Process each text prompt
         for text in prompts:
             print(f" Processing prompt: '{text}'")
-            # Get annotation for the specific text prompt
             ann = prompt_process.text_prompt(text=text)

-            # Check annotation format and extract masks
             current_masks = None
             num_found = 0
-            # Adjust check based on actual structure of 'ann' for text_prompt
+            # Check annotation format - adapt based on text_prompt output structure
             if isinstance(ann, list) and len(ann) > 0 and isinstance(ann[0], dict) and 'masks' in ann[0]:
                 mask_tensor = ann[0]['masks']
-                if mask_tensor is not None and mask_tensor.numel() > 0:
+                if mask_tensor is not None and isinstance(mask_tensor, torch.Tensor) and mask_tensor.numel() > 0:
                     current_masks = mask_tensor.cpu().numpy()
                     num_found = len(current_masks)
                     print(f" Found {num_found} mask(s) for '{text}'. Shape: {current_masks.shape}")
-                    all_matching_masks.extend(current_masks) # Add found masks to the list
+                    all_matching_masks.extend(current_masks) # Add found masks
                 else:
-                    print(f" Annotation 'masks' tensor is None or empty for '{text}'.")
+                    print(f" Annotation 'masks' tensor is None, not a Tensor, or empty for '{text}'.")
             else:
                 print(f" No masks found or annotation format unexpected for '{text}'. ann type: {type(ann)}")
-                if isinstance(ann, list) and len(ann) > 0:
-                    print(f" First element of ann for '{text}': {ann[0]}")
+                if isinstance(ann, list) and len(ann) > 0: print(f" First element of ann for '{text}': {ann[0]}")

-            found_prompts_details.append(f"{text} ({num_found})") # Record count for status
+            found_prompts_details.append(f"{text} ({num_found})")

         # Prepare output image
         output_image = image_pil.copy()
@@ -427,50 +436,49 @@ def run_text_prompted_segmentation(image_pil: Image.Image, text_prompts: str, co
         # Draw all collected masks if any were found
         if all_matching_masks:
             print(f"Total masks collected across all prompts: {len(all_matching_masks)}")
-            # Stack masks if needed (optional, can draw one by one)
-            # masks_np = np.stack(all_matching_masks, axis=0)
-            # print(f"Total masks stacked shape: {masks_np.shape}")
-
             overlay = Image.new('RGBA', output_image.size, (0, 0, 0, 0))
             draw = ImageDraw.Draw(overlay)
+            valid_masks_drawn = 0

-            for i, mask in enumerate(all_matching_masks): # Iterate through collected masks
+            for i, mask in enumerate(all_matching_masks):
                 binary_mask = (mask > 0)
                 mask_uint8 = binary_mask.astype(np.uint8) * 255
-                if mask_uint8.max() == 0:
-                    continue # Skip empty masks
+                if mask_uint8.max() == 0: continue

-                # Assign a unique color per mask or per prompt (using random here)
                 color = (random.randint(50, 255), random.randint(50, 255), random.randint(50, 255), 180)
                 try:
                     mask_image = Image.fromarray(mask_uint8, mode='L')
                     draw.bitmap((0, 0), mask_image, fill=color)
+                    valid_masks_drawn += 1
                 except Exception as draw_err:
                     print(f"Error drawing collected mask {i}: {draw_err}")
                     traceback.print_exc()
-                    continue

-            # Composite the overlay
-            try:
-                output_image_rgba = output_image.convert('RGBA')
-                output_image_composited = Image.alpha_composite(output_image_rgba, overlay)
-                output_image = output_image_composited.convert('RGB')
-                print("Text prompt mask drawing and compositing finished.")
-            except Exception as comp_err:
-                print(f"Error during alpha compositing for text prompts: {comp_err}")
-                traceback.print_exc()
-                output_image = image_pil # Fallback
-                status_message += " (Error during visualization)"
+            if valid_masks_drawn > 0:
+                try:
+                    output_image_rgba = output_image.convert('RGBA')
+                    output_image_composited = Image.alpha_composite(output_image_rgba, overlay)
+                    output_image = output_image_composited.convert('RGB')
+                    print("Text prompt mask drawing and compositing finished.")
+                    # Append drawing status if needed
+                    if valid_masks_drawn < len(all_matching_masks):
+                        status_message += f" (Drew {valid_masks_drawn}/{len(all_matching_masks)} found masks)"
+                except Exception as comp_err:
+                    print(f"Error during alpha compositing for text prompts: {comp_err}")
+                    traceback.print_exc()
+                    output_image = image_pil # Fallback
+                    status_message += " (Error during visualization)"
+            else:
+                output_image = image_pil # Return original if no masks drawn
+                status_message += " (No valid masks to draw)"
         else:
             print("No matching masks found for any text prompt.")
-            # status_message is already set
+            output_image = image_pil # Return original image

         # Save for debugging
         if output_image:
             try:
-                debug_path = "debug_fastsam_text_output.png"
-                output_image.save(debug_path)
-                print(f"Saved debug output to {debug_path}")
+                output_image.save("debug_fastsam_text_output.png")
             except Exception as save_err:
                 print(f"Failed to save debug image: {save_err}")

@@ -479,76 +487,180 @@ def run_text_prompted_segmentation(image_pil: Image.Image, text_prompts: str, co
     except Exception as e:
         print(f"Error during FastSAM text-prompted processing: {e}")
         traceback.print_exc()
-        # Return original image and error message
        return image_pil, f"Error during processing: {e}"

-# --- Gradio Interface ---
-
+# --- Preload Models ---
 print("Attempting to preload models...")
-load_clip_model() # Preload CLIP
-load_fastsam_model() # Preload FastSAM
-print("Preloading finished (check logs above for errors).")
+load_clip_model()
+load_fastsam_model() # Try to load FastSAM eagerly
+print("Preloading finished (check logs above for success/errors).")


 # --- Gradio Interface Definition ---
-# (Your Gradio Blocks code remains largely the same, but ensure the outputs match the function returns)
-
-# --- Gradio Interface ---
-# ... (imports and functions) ...
-
-with gr.Blocks(theme=gr.themes.Soft()) as demo: # START of the block
+with gr.Blocks(theme=gr.themes.Soft()) as demo:
     gr.Markdown("# CLIP & FastSAM Demo")
-    # ... other UI elements ...
+    gr.Markdown("Explore Zero-Shot Classification, 'Segment Everything', and Text-Prompted Segmentation.")
+    gr.Markdown("---")
+    gr.Markdown("**NOTE:** Ensure required libraries are installed: `pip install --upgrade gradio torch transformers Pillow numpy wget ultralytics` and `pip install git+https://github.com/CASIA-IVA-Lab/FastSAM.git`")
+    gr.Markdown("---")
+

     with gr.Tabs():
+        # --- CLIP Tab ---
         with gr.TabItem("CLIP Zero-Shot Classification"):
-            gr.Markdown("Upload an image and provide comma-separated labels...")
+            gr.Markdown("Upload an image and provide comma-separated labels (e.g., 'cat, dog, car').")
             with gr.Row():
                 with gr.Column(scale=1):
+                    # Define UI elements first
                     clip_input_image = gr.Image(type="pil", label="Input Image")
-                    clip_text_labels = gr.Textbox(label="Comma-Separated Labels", placeholder="e.g., astronaut, moon") # DEFINE the button
+                    clip_text_labels = gr.Textbox(label="Comma-Separated Labels", placeholder="e.g., astronaut, moon")
                     clip_button = gr.Button("Run CLIP Classification", variant="primary")
                 with gr.Column(scale=1):
                     clip_output_label = gr.Label(label="Classification Probabilities")
-                    clip_output_image_display = gr.Image(type="pil", label="Input Image Preview")
+                    clip_output_image_display = gr.Image(type="pil", label="Input Image Preview", interactive=False)

-            # ATTACH the click handler *inside* the block, after the button is defined
+            # Define the click handler AFTER elements are defined
             clip_button.click(
                 run_clip_zero_shot,
                 inputs=[clip_input_image, clip_text_labels],
                 outputs=[clip_output_label, clip_output_image_display]
             )
-            # ... CLIP examples ...

+            gr.Examples(
+                examples=[
+                    ["examples/astronaut.jpg", "astronaut, moon, rover"],
+                    ["examples/dog_bike.jpg", "dog, bicycle, person"],
+                    ["examples/clip_logo.png", "logo, text, graphics"],
+                ],
+                inputs=[clip_input_image, clip_text_labels],
+                outputs=[clip_output_label, clip_output_image_display],
+                fn=run_clip_zero_shot,
+                cache_examples=False, # Keep False during debugging
+            )
+
+        # --- FastSAM Everything Tab ---
         with gr.TabItem("FastSAM Segment Everything"):
-            # ... FastSAM Everything UI elements ...
-            fastsam_button_all = gr.Button(...) # Define button
+            gr.Markdown("Upload an image to segment all objects/regions.")
+            with gr.Row():
+                with gr.Column(scale=1):
+                    # Define UI elements first
+                    fastsam_input_image_all = gr.Image(type="pil", label="Input Image")
+                    with gr.Row():
+                        fastsam_conf_all = gr.Slider(minimum=0.1, maximum=1.0, value=0.4, step=0.05, label="Confidence Threshold")
+                        fastsam_iou_all = gr.Slider(minimum=0.1, maximum=1.0, value=0.9, step=0.05, label="IoU Threshold")
+                    fastsam_button_all = gr.Button("Run FastSAM Segmentation", variant="primary")
+                with gr.Column(scale=1):
+                    fastsam_output_image_all = gr.Image(type="pil", label="Segmented Image", interactive=False)
+                    fastsam_status_all = gr.Textbox(label="Status", interactive=False)

-            # Attach click handler *inside* the block
+            # Define the click handler AFTER elements are defined
             fastsam_button_all.click(
                 run_fastsam_segmentation,
-                inputs=[fastsam_input_image_all, fastsam_conf_all, fastsam_iou_all], # Correct list of inputs
+                inputs=[fastsam_input_image_all, fastsam_conf_all, fastsam_iou_all], # Correct inputs list
                 outputs=[fastsam_output_image_all, fastsam_status_all]
             )
-            # ... FastSAM Everything examples ...

+            gr.Examples(
+                examples=[
+                    ["examples/dogs.jpg", 0.4, 0.9],
+                    ["examples/fruits.jpg", 0.5, 0.8],
+                    ["examples/lion.jpg", 0.45, 0.9],
+                ],
+                inputs=[fastsam_input_image_all, fastsam_conf_all, fastsam_iou_all],
+                outputs=[fastsam_output_image_all, fastsam_status_all],
+                fn=run_fastsam_segmentation,
+                cache_examples=False,
+            )
+
+        # --- Text-Prompted Segmentation Tab ---
         with gr.TabItem("Text-Prompted Segmentation"):
-            # ... Text-Prompted UI elements ...
-            prompt_button = gr.Button(...) # Define button
+            gr.Markdown("Upload an image and provide comma-separated prompts (e.g., 'person, dog').")
+            with gr.Row():
+                with gr.Column(scale=1):
+                    # Define UI elements first
+                    prompt_input_image = gr.Image(type="pil", label="Input Image")
+                    prompt_text_input = gr.Textbox(label="Comma-Separated Text Prompts", placeholder="e.g., glasses, watch")
+                    with gr.Row():
+                        prompt_conf = gr.Slider(minimum=0.1, maximum=1.0, value=0.4, step=0.05, label="Confidence Threshold")
+                        prompt_iou = gr.Slider(minimum=0.1, maximum=1.0, value=0.9, step=0.05, label="IoU Threshold")
+                    prompt_button = gr.Button("Segment by Text", variant="primary")
+                with gr.Column(scale=1):
+                    prompt_output_image = gr.Image(type="pil", label="Text-Prompted Segmentation", interactive=False)
+                    prompt_status_message = gr.Textbox(label="Status", interactive=False)

-            # Attach click handler *inside* the block
+            # Define the click handler AFTER elements are defined
             prompt_button.click(
                 run_text_prompted_segmentation,
-                inputs=[...],
-                outputs=[...]
+                inputs=[prompt_input_image, prompt_text_input, prompt_conf, prompt_iou], # Correct inputs list
+                outputs=[prompt_output_image, prompt_status_message]
+            )
+
+            gr.Examples(
+                examples=[
+                    ["examples/dog_bike.jpg", "person, bicycle", 0.4, 0.9],
+                    ["examples/astronaut.jpg", "person, helmet", 0.35, 0.9],
+                    ["examples/dogs.jpg", "dog", 0.4, 0.9],
+                    ["examples/fruits.jpg", "banana, apple", 0.5, 0.8],
+                    ["examples/teacher.jpg", "person, glasses", 0.4, 0.9],
+                ],
+                inputs=[prompt_input_image, prompt_text_input, prompt_conf, prompt_iou],
+                outputs=[prompt_output_image, prompt_status_message],
+                fn=run_text_prompted_segmentation,
+                cache_examples=False,
             )
-            # ... Text-Prompted examples ...

-# The `with` block ends here.
-# --- Example File Download (This is correctly outside the block) ---
-# ... download logic ...
+# --- Example File Download ---
+# (This logic should be outside the `with gr.Blocks...` block)
+if not os.path.exists("examples"):
+    try:
+        os.makedirs("examples")
+        print("Created 'examples' directory.")
+    except OSError as e:
+        print(f"Error creating 'examples' directory: {e}")
+
+example_files = {
+    "astronaut.jpg": "https://upload.wikimedia.org/wikipedia/commons/thumb/d/d1/Astronaut_-_St._Jean_Bay.jpg/640px-Astronaut_-_St._Jean_Bay.jpg",
+    "dog_bike.jpg": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/gradio/outputs_multimodal.jpg",
+    "clip_logo.png": "https://raw.githubusercontent.com/openai/CLIP/main/CLIP.png",
+    "dogs.jpg": "https://raw.githubusercontent.com/ultralytics/assets/main/im/image8.jpg",
+    "fruits.jpg": "https://raw.githubusercontent.com/ultralytics/assets/main/im/image9.jpg",
+    "lion.jpg": "https://huggingface.co/spaces/gradio/image-segmentation/resolve/main/images/lion.jpg",
+    "teacher.jpg": "https://images.pexels.com/photos/848117/pexels-photo-848117.jpeg?auto=compress&cs=tinysrgb&w=600"
+}
+
+def download_example_file(filename, url, retries=3):
+    filepath = os.path.join("examples", filename)
+    if not os.path.exists(filepath):
+        print(f"Attempting to download {filename}...")
+        for attempt in range(retries):
+            try:
+                wget.download(url, filepath)
+                print(f"Downloaded {filename} successfully.")
+                return # Exit function on success
+            except Exception as e:
+                print(f"Download attempt {attempt + 1}/{retries} for {filename} failed: {e}")
+                if os.path.exists(filepath): # Clean up partial download
+                    try: os.remove(filepath)
+                    except OSError: pass
+                if attempt + 1 == retries:
+                    print(f"Failed to download {filename} after {retries} attempts.")
+    # else: # Optional: uncomment if you want confirmation for existing files
+    #     print(f"Example file {filename} already exists.")
+
+# Trigger downloads if directory exists
+if os.path.exists("examples"):
+    for filename, url in example_files.items():
+        download_example_file(filename, url)
+    print("Example file check/download process complete.")
+else:
+    print("Skipping example download because 'examples' directory could not be created.")
+

-# --- Launch App (This is correctly outside the block) ---
+# --- Launch App ---
 if __name__ == "__main__":
+    print("-----------------------------------------")
     print("Launching Gradio Demo...")
-    demo.launch(debug=True)
+    print("Ensure FastSAM model and weights are correctly loaded (check logs above).")
+    print("If FastSAM fails, check installation: pip install ultralytics && pip install git+https://github.com/CASIA-IVA-Lab/FastSAM.git")
+    print("-----------------------------------------")
+    demo.launch(debug=True) # Keep debug=True for detailed Gradio errors
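The interface rewrite in the last hunk is mostly about ordering: every component must exist, inside the with gr.Blocks() context, before a .click() handler references it. The skeleton of that pattern, reduced to a single tab (component names here are illustrative):

    import gradio as gr

    def classify(image, labels):
        return {"placeholder": 1.0}, image  # stand-in for run_clip_zero_shot

    with gr.Blocks() as demo:
        with gr.Tabs():
            with gr.TabItem("CLIP Zero-Shot Classification"):
                in_image = gr.Image(type="pil", label="Input Image")
                in_labels = gr.Textbox(label="Comma-Separated Labels")
                btn = gr.Button("Run")
                out_label = gr.Label()
                out_image = gr.Image(type="pil", interactive=False)
                # Attached after all referenced components exist, still
                # inside the Blocks context.
                btn.click(classify, inputs=[in_image, in_labels],
                          outputs=[out_label, out_image])

    if __name__ == "__main__":
        demo.launch()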