sagar007 committed
Commit 03c5849 · verified · 1 Parent(s): 23fa119

Update app.py

Files changed (1):
  1. app.py +151 -189
app.py CHANGED
app.py (original file; deletions marked with -):

@@ -1,22 +1,20 @@
 import gradio as gr
 import torch
-from transformers import AutoProcessor, AutoModel # Keep CLIP for potential future use or if FastSAM's text prompt isn't enough
 from PIL import Image, ImageDraw, ImageFont
 import numpy as np
 import random
 import os
-import wget # To download weights
-import traceback # For detailed error printing

 # --- Configuration & Model Loading ---

-# Device Selection
-DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
-# Force CPU if CUDA fails or isn't desired (sometimes needed on Spaces free tier)
-# DEVICE = "cpu"
 print(f"Using device: {DEVICE}")

-# --- CLIP Setup (Kept in case needed, but FastSAM's method is primary now) ---
 CLIP_MODEL_ID = "openai/clip-vit-base-patch32"
 clip_processor = None
 clip_model = None
@@ -30,7 +28,7 @@ def load_clip_model():
             print("CLIP processor loaded.")
         except Exception as e:
             print(f"Error loading CLIP processor: {e}")
-            return False # Indicate failure
     if clip_model is None:
         try:
             print(f"Loading CLIP model: {CLIP_MODEL_ID}...")
@@ -38,255 +36,218 @@ def load_clip_model():
             print(f"CLIP model loaded to {DEVICE}.")
         except Exception as e:
             print(f"Error loading CLIP model: {e}")
-            return False # Indicate failure
-    return True # Indicate success
-

 # --- FastSAM Setup ---
 FASTSAM_CHECKPOINT = "FastSAM-s.pt"
 FASTSAM_CHECKPOINT_URL = f"https://huggingface.co/CASIA-IVA-Lab/FastSAM-s/resolve/main/{FASTSAM_CHECKPOINT}"

 fastsam_model = None
-fastsam_lib_imported = False # Flag to check if import worked

 def check_and_import_fastsam():
     global fastsam_lib_imported
     if not fastsam_lib_imported:
         try:
             from fastsam import FastSAM, FastSAMPrompt
-            globals()['FastSAM'] = FastSAM # Make classes available globally
             globals()['FastSAMPrompt'] = FastSAMPrompt
             fastsam_lib_imported = True
             print("fastsam library imported successfully.")
-        except ImportError:
-            print("Error: 'fastsam' library not found or import failed.")
-            print("Please ensure 'fastsam' is installed correctly (pip install fastsam).")
             fastsam_lib_imported = False
         except Exception as e:
-            print(f"An unexpected error occurred during fastsam import: {e}")
             traceback.print_exc()
             fastsam_lib_imported = False
     return fastsam_lib_imported

-
-def download_fastsam_weights():
     if not os.path.exists(FASTSAM_CHECKPOINT):
         print(f"Downloading FastSAM weights: {FASTSAM_CHECKPOINT} from {FASTSAM_CHECKPOINT_URL}...")
-        try:
-            wget.download(FASTSAM_CHECKPOINT_URL, FASTSAM_CHECKPOINT)
-            print("FastSAM weights downloaded.")
-        except Exception as e:
-            print(f"Error downloading FastSAM weights: {e}")
-            print("Please ensure the URL is correct and reachable, or manually place the weights file.")
-            if os.path.exists(FASTSAM_CHECKPOINT):
-                try: os.remove(FASTSAM_CHECKPOINT)
-                except OSError: pass
-            return False
     return os.path.exists(FASTSAM_CHECKPOINT)

 def load_fastsam_model():
     global fastsam_model
     if fastsam_model is None:
         if not check_and_import_fastsam():
-            print("Cannot load FastSAM model because the library couldn't be imported.")
-            return False # Indicate failure
-
         if download_fastsam_weights():
             try:
                 print(f"Loading FastSAM model: {FASTSAM_CHECKPOINT}...")
                 fastsam_model = FastSAM(FASTSAM_CHECKPOINT)
-                # The FastSAM model itself doesn't need explicit .to(DEVICE)
-                # It seems to handle device selection internally or via the prompt process
-                print(f"FastSAM model loaded.")
-                return True # Indicate success
             except Exception as e:
                 print(f"Error loading FastSAM model: {e}")
                 traceback.print_exc()

         else:
-            print("FastSAM weights not found or download failed. Cannot load model.")
-    return fastsam_model is not None # Return True if already loaded or loaded successfully
-

 # --- Processing Functions ---

-# (Keep run_clip_zero_shot and run_fastsam_segmentation as they were for the other tabs)
-# CLIP Zero-Shot Classification Function
 def run_clip_zero_shot(image: Image.Image, text_labels: str):
-    # Load CLIP if needed
     if clip_model is None or clip_processor is None:
         if not load_clip_model():
-            return "Error: CLIP Model could not be loaded. Check logs.", None
-
-    if image is None: return "Please upload an image.", None
-    if not text_labels: return {}, image # Return empty dict, show image

     labels = [label.strip() for label in text_labels.split(',') if label.strip()]
-    if not labels: return {}, image

     print(f"Running CLIP zero-shot classification with labels: {labels}")
     try:
-        if image.mode != "RGB": image = image.convert("RGB")
         inputs = clip_processor(text=labels, images=image, return_tensors="pt", padding=True).to(DEVICE)
         with torch.no_grad():
             outputs = clip_model(**inputs)
             probs = outputs.logits_per_image.softmax(dim=1)
-        print("CLIP processing complete.")
         confidences = {labels[i]: float(probs[0, i].item()) for i in range(len(labels))}
         return confidences, image
     except Exception as e:
         print(f"Error during CLIP processing: {e}")
         traceback.print_exc()
-        return f"An error occurred during CLIP: {e}", image

-# FastSAM Everything Segmentation Function (for the second tab)
 def run_fastsam_segmentation(image_pil: Image.Image, conf_threshold: float = 0.4, iou_threshold: float = 0.9):
-    if not load_fastsam_model():
-        return "Error: FastSAM Model not loaded. Check logs."
-    if not fastsam_lib_imported:
-        return "Error: FastSAM library not available."
-    if image_pil is None: return "Please upload an image."

     print("Running FastSAM 'segment everything'...")
     try:
-        if image_pil.mode != "RGB": image_pil = image_pil.convert("RGB")
         image_np_rgb = np.array(image_pil)

         everything_results = fastsam_model(
             image_np_rgb, device=DEVICE, retina_masks=True, imgsz=640,
-            conf=conf_threshold, iou=iou_threshold,
         )
         prompt_process = FastSAMPrompt(image_np_rgb, everything_results, device=DEVICE)
         ann = prompt_process.everything_prompt()
-        print(f"FastSAM 'everything' found {len(ann[0]['masks']) if ann and ann[0] and 'masks' in ann[0] else 0} masks.")

-        # Plotting
         output_image = image_pil.copy()
-        if ann and ann[0] is not None and 'masks' in ann[0] and len(ann[0]['masks']) > 0:
             masks = ann[0]['masks'].cpu().numpy()
             overlay = Image.new('RGBA', output_image.size, (0, 0, 0, 0))
             draw = ImageDraw.Draw(overlay)
             for mask in masks:
-                color = (random.randint(50, 255), random.randint(50, 255), random.randint(50, 255), 128)
-                mask_image = Image.fromarray((mask * 255).astype(np.uint8), mode='L')
                 draw.bitmap((0, 0), mask_image, fill=color)
             output_image = Image.alpha_composite(output_image.convert('RGBA'), overlay).convert('RGB')
-
-        print("FastSAM 'everything' processing complete.")
         return output_image
-
     except Exception as e:
         print(f"Error during FastSAM 'everything' processing: {e}")
         traceback.print_exc()
-        return f"An error occurred during FastSAM 'everything': {e}"
-

-# --- NEW: Text-Prompted Segmentation Function ---
 def run_text_prompted_segmentation(image_pil: Image.Image, text_prompts: str, conf_threshold: float = 0.4, iou_threshold: float = 0.9):
-    """Segments objects based on text prompts."""
     if not load_fastsam_model():
-        return "Error: FastSAM Model not loaded. Check logs.", "No prompts provided."
     if not fastsam_lib_imported:
-        return "Error: FastSAM library not available.", "FastSAM library error."
     if image_pil is None:
         return "Please upload an image.", "No image provided."
     if not text_prompts:
-        return image_pil, "Please enter text prompts (e.g., 'person, dog')." # Return original image and message

     prompts = [p.strip() for p in text_prompts.split(',') if p.strip()]
     if not prompts:
         return image_pil, "No valid text prompts entered."

     print(f"Running FastSAM text-prompted segmentation for: {prompts}")
-
     try:
         if image_pil.mode != "RGB":
             image_pil = image_pil.convert("RGB")
         image_np_rgb = np.array(image_pil)

-        # 1. Run FastSAM once to get all potential results
-        # NOTE: We might optimize later, but this is the standard way FastSAMPrompt works.
         everything_results = fastsam_model(
             image_np_rgb, device=DEVICE, retina_masks=True, imgsz=640,
-            conf=conf_threshold, iou=iou_threshold, verbose=False # Less console spam
         )
-
-        # 2. Create the prompt processor
         prompt_process = FastSAMPrompt(image_np_rgb, everything_results, device=DEVICE)
-
-        # 3. Use text_prompt for each prompt and collect masks
         all_matching_masks = []
         found_prompts = []

         for text in prompts:
             print(f" Processing prompt: '{text}'")
-            # Ann is a list of dictionaries, one per image. We have one image.
-            # Each dict can have 'masks', 'bboxes', 'points'.
-            # text_prompt filters 'everything_results' based on CLIP-like similarity.
-            # It might return multiple masks if multiple instances match the text.
             ann = prompt_process.text_prompt(text=text)
-
-            if ann and ann[0] is not None and 'masks' in ann[0] and len(ann[0]['masks']) > 0:
                 num_found = len(ann[0]['masks'])
-                print(f" Found {num_found} mask(s) matching '{text}'.")
                 found_prompts.append(f"{text} ({num_found})")
-                masks = ann[0]['masks'].cpu().numpy() # Get masks as numpy array (N, H, W)
-                all_matching_masks.extend(masks) # Add the numpy arrays to the list
             else:
-                print(f" No masks found matching '{text}'.")
                 found_prompts.append(f"{text} (0)")

-        # 4. Plot the collected masks
         output_image = image_pil.copy()
-        status_message = f"Found segments for: {', '.join(found_prompts)}" if found_prompts else "No matching segments found for any prompt."
-
-        if not all_matching_masks:
-            print("No matching masks found for any prompt.")
-            return output_image, status_message # Return original image if nothing matched
-
-        # Convert list of (H, W) masks to a single (N, H, W) array for consistent processing
-        masks_np = np.stack(all_matching_masks, axis=0) # Shape (TotalMasks, H, W)
-
-        overlay = Image.new('RGBA', output_image.size, (0, 0, 0, 0))
-        draw = ImageDraw.Draw(overlay)

-        for i in range(masks_np.shape[0]):
-            mask = masks_np[i] # Shape (H, W), boolean
-            color = (random.randint(50, 255), random.randint(50, 255), random.randint(50, 255), 150) # RGBA with slightly more alpha
-            mask_image = Image.fromarray((mask * 255).astype(np.uint8), mode='L')
-            draw.bitmap((0, 0), mask_image, fill=color)
-
-        output_image = Image.alpha_composite(output_image.convert('RGBA'), overlay).convert('RGB')

-        print("FastSAM text-prompted processing complete.")
         return output_image, status_message
-
     except Exception as e:
         print(f"Error during FastSAM text-prompted processing: {e}")
         traceback.print_exc()
-        return f"An error occurred: {e}", "Error during processing."
-

 # --- Gradio Interface ---

 print("Attempting to preload models...")
-# load_clip_model() # Load CLIP lazily if needed
-load_fastsam_model() # Load FastSAM eagerly
-print("Preloading finished (or attempted).")
-

 with gr.Blocks(theme=gr.themes.Soft()) as demo:
     gr.Markdown("# CLIP & FastSAM Demo")
     gr.Markdown("Explore Zero-Shot Classification, 'Segment Everything', and Text-Prompted Segmentation.")

     with gr.Tabs():
-        # --- CLIP Tab (No changes) ---
         with gr.TabItem("CLIP Zero-Shot Classification"):
-            # ... (keep the existing layout and logic for CLIP) ...
-            gr.Markdown("Upload an image and provide comma-separated candidate labels (e.g., 'cat, dog, car'). CLIP will predict the probability of the image matching each label.")
             with gr.Row():
                 with gr.Column(scale=1):
                     clip_input_image = gr.Image(type="pil", label="Input Image")
-                    clip_text_labels = gr.Textbox(label="Comma-Separated Labels", placeholder="e.g., astronaut, moon, dog playing fetch")
                     clip_button = gr.Button("Run CLIP Classification", variant="primary")
                 with gr.Column(scale=1):
                     clip_output_label = gr.Label(label="Classification Probabilities")
@@ -298,70 +259,69 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
             )
             gr.Examples(
                 examples=[
-                    ["examples/astronaut.jpg", "astronaut, moon, rover, mountain"],
-                    ["examples/dog_bike.jpg", "dog, bicycle, person, park, grass"],
-                    ["examples/clip_logo.png", "logo, text, graphics, abstract art"],
                 ],
                 inputs=[clip_input_image, clip_text_labels],
-                outputs=[clip_output_label, clip_output_image_display], fn=run_clip_zero_shot, cache_examples=False,
             )

-
-        # --- FastSAM Everything Tab (No changes) ---
         with gr.TabItem("FastSAM Segment Everything"):
-            # ... (keep the existing layout and logic for segment everything) ...
-            gr.Markdown("Upload an image. FastSAM will attempt to segment all objects/regions in the image.")
-            with gr.Row():
-                with gr.Column(scale=1):
-                    fastsam_input_image_all = gr.Image(type="pil", label="Input Image", elem_id="fastsam_input_all") # Unique elem_id if needed
-                    with gr.Row():
-                        fastsam_conf_all = gr.Slider(minimum=0.1, maximum=1.0, value=0.4, step=0.05, label="Confidence Threshold")
-                        fastsam_iou_all = gr.Slider(minimum=0.1, maximum=1.0, value=0.9, step=0.05, label="IoU Threshold")
-                    fastsam_button_all = gr.Button("Run FastSAM Segmentation", variant="primary")
-                with gr.Column(scale=1):
-                    fastsam_output_image_all = gr.Image(type="pil", label="Segmented Image", elem_id="fastsam_output_all")
-            fastsam_button_all.click(
-                run_fastsam_segmentation,
-                inputs=[fastsam_input_image_all, fastsam_conf_all, fastsam_iou_all],
-                outputs=[fastsam_output_image_all]
-            )
-            gr.Examples(
-                examples=[
-                    ["examples/dogs.jpg", 0.4, 0.9],
-                    ["examples/fruits.jpg", 0.5, 0.8],
-                    ["examples/lion.jpg", 0.45, 0.9],
-                ],
-                inputs=[fastsam_input_image_all, fastsam_conf_all, fastsam_iou_all],
-                outputs=[fastsam_output_image_all], fn=run_fastsam_segmentation, cache_examples=False,
-            )
-
-        # --- NEW: Text-Prompted Segmentation Tab ---
         with gr.TabItem("Text-Prompted Segmentation"):
-            gr.Markdown("Upload an image and provide comma-separated text prompts (e.g., 'person, dog, backpack'). FastSAM + CLIP (internally) will segment only the objects matching the text.")
             with gr.Row():
                 with gr.Column(scale=1):
                     prompt_input_image = gr.Image(type="pil", label="Input Image")
-                    prompt_text_input = gr.Textbox(label="Comma-Separated Text Prompts", placeholder="e.g., glasses, watch, t-shirt")
-                    with gr.Row(): # Reuse confidence/IoU sliders if desired
                         prompt_conf = gr.Slider(minimum=0.1, maximum=1.0, value=0.4, step=0.05, label="Confidence Threshold")
                         prompt_iou = gr.Slider(minimum=0.1, maximum=1.0, value=0.9, step=0.05, label="IoU Threshold")
                     prompt_button = gr.Button("Segment by Text", variant="primary")
                 with gr.Column(scale=1):
                     prompt_output_image = gr.Image(type="pil", label="Text-Prompted Segmentation")
-                    prompt_status_message = gr.Textbox(label="Status", interactive=False) # To show which prompts matched
-
             prompt_button.click(
                 run_text_prompted_segmentation,
                 inputs=[prompt_input_image, prompt_text_input, prompt_conf, prompt_iou],
-                outputs=[prompt_output_image, prompt_status_message] # Map to image and status box
             )
             gr.Examples(
                 examples=[
                     ["examples/dog_bike.jpg", "person, bicycle", 0.4, 0.9],
                     ["examples/astronaut.jpg", "person, helmet", 0.35, 0.9],
-                    ["examples/dogs.jpg", "dog", 0.4, 0.9], # Should find multiple dogs
                     ["examples/fruits.jpg", "banana, apple", 0.5, 0.8],
-                    ["examples/teacher.jpg", "person, glasses, blackboard", 0.4, 0.9], # Download this image or use another one with glasses/blackboard
                 ],
                 inputs=[prompt_input_image, prompt_text_input, prompt_conf, prompt_iou],
                 outputs=[prompt_output_image, prompt_status_message],
@@ -369,31 +329,33 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
                 cache_examples=False,
             )

-# Ensure example images exist or are downloaded
-# (Keep the existing example download logic, maybe add teacher.jpg if used in examples)
 if not os.path.exists("examples"):
     os.makedirs("examples")
-    print("Created 'examples' directory. Attempting to download sample images...")
-example_files = {
-    "astronaut.jpg": "https://upload.wikimedia.org/wikipedia/commons/thumb/d/d1/Astronaut_-_St._Jean_Bay.jpg/640px-Astronaut_-_St._Jean_Bay.jpg",
-    "dog_bike.jpg": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/gradio/outputs_multimodal.jpg",
-    "clip_logo.png": "https://raw.githubusercontent.com/openai/CLIP/main/CLIP.png",
-    "dogs.jpg": "https://raw.githubusercontent.com/ultralytics/assets/main/im/image8.jpg",
-    "fruits.jpg": "https://raw.githubusercontent.com/ultralytics/assets/main/im/image9.jpg",
-    "lion.jpg": "https://huggingface.co/spaces/gradio/image-segmentation/resolve/main/images/lion.jpg",
-    "teacher.jpg": "https://images.pexels.com/photos/848117/pexels-photo-848117.jpeg?auto=compress&cs=tinysrgb&w=600" # Example with glasses/board
-}
-for filename, url in example_files.items():
-    filepath = os.path.join("examples", filename)
-    if not os.path.exists(filepath):
-        try:
-            print(f"Downloading {filename}...")
-            wget.download(url, filepath)
-        except Exception as e:
-            print(f"Could not download {filename} from {url}: {e}")
-print("Example image download attempt finished.")
-
-
-# Launch the Gradio app

 if __name__ == "__main__":
-    demo.launch(debug=True) # debug=True is helpful locally

app.py (updated file; additions marked with +):

 import gradio as gr
 import torch
+from transformers import AutoProcessor, AutoModel
 from PIL import Image, ImageDraw, ImageFont
 import numpy as np
 import random
 import os
+import wget
+import traceback

 # --- Configuration & Model Loading ---

+# Device Selection with fallback
+DEVICE = "cuda" if torch.cuda.is_available() and torch.cuda.current_device() >= 0 else "cpu"
 print(f"Using device: {DEVICE}")
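Note: the reworked device check is effectively the old one-liner in disguise, since `torch.cuda.current_device()` returns the default index 0 whenever `torch.cuda.is_available()` is true, so the `>= 0` clause never fires. A sketch of a fallback that actually catches a half-broken CUDA install (the `pick_device` helper is illustrative, not part of this commit):

    import torch

    def pick_device() -> str:
        # A tiny allocation exercises the driver; broken CUDA setups raise here.
        if torch.cuda.is_available():
            try:
                torch.zeros(1, device="cuda")
                return "cuda"
            except Exception:
                pass
        return "cpu"

    DEVICE = pick_device()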

+# --- CLIP Setup ---
 CLIP_MODEL_ID = "openai/clip-vit-base-patch32"
 clip_processor = None
 clip_model = None
@@ -30,7 +28,7 @@ def load_clip_model():
             print("CLIP processor loaded.")
         except Exception as e:
             print(f"Error loading CLIP processor: {e}")
+            return False
     if clip_model is None:
         try:
             print(f"Loading CLIP model: {CLIP_MODEL_ID}...")
@@ -38,255 +36,218 @@ def load_clip_model():
             print(f"CLIP model loaded to {DEVICE}.")
         except Exception as e:
             print(f"Error loading CLIP model: {e}")
+            return False
+    return True

 # --- FastSAM Setup ---
 FASTSAM_CHECKPOINT = "FastSAM-s.pt"
 FASTSAM_CHECKPOINT_URL = f"https://huggingface.co/CASIA-IVA-Lab/FastSAM-s/resolve/main/{FASTSAM_CHECKPOINT}"

 fastsam_model = None
+fastsam_lib_imported = False

 def check_and_import_fastsam():
     global fastsam_lib_imported
     if not fastsam_lib_imported:
         try:
             from fastsam import FastSAM, FastSAMPrompt
+            globals()['FastSAM'] = FastSAM
             globals()['FastSAMPrompt'] = FastSAMPrompt
             fastsam_lib_imported = True
             print("fastsam library imported successfully.")
+        except ImportError as e:
+            print(f"Error: 'fastsam' library not found. Install with 'pip install fastsam': {e}")
             fastsam_lib_imported = False
         except Exception as e:
+            print(f"Unexpected error during fastsam import: {e}")
             traceback.print_exc()
             fastsam_lib_imported = False
     return fastsam_lib_imported

+def download_fastsam_weights(retries=3):
     if not os.path.exists(FASTSAM_CHECKPOINT):
         print(f"Downloading FastSAM weights: {FASTSAM_CHECKPOINT} from {FASTSAM_CHECKPOINT_URL}...")
+        for attempt in range(retries):
+            try:
+                wget.download(FASTSAM_CHECKPOINT_URL, FASTSAM_CHECKPOINT)
+                print("FastSAM weights downloaded.")
+                break
+            except Exception as e:
+                print(f"Attempt {attempt + 1}/{retries} failed: {e}")
+                if attempt + 1 == retries:
+                    print("Failed to download weights after all attempts.")
+                    return False
     return os.path.exists(FASTSAM_CHECKPOINT)
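Note: the new retry loop fires its attempts back to back; for transient network errors a short, growing pause between attempts is usually more effective. A minimal sketch with exponential backoff (the `download_with_backoff` helper and its delays are illustrative, not part of this commit):

    import time
    import wget

    def download_with_backoff(url: str, dest: str, retries: int = 3, base_delay_s: float = 2.0) -> bool:
        for attempt in range(1, retries + 1):
            try:
                wget.download(url, dest)
                return True
            except Exception as e:
                print(f"Attempt {attempt}/{retries} failed: {e}")
                if attempt < retries:
                    time.sleep(base_delay_s * 2 ** (attempt - 1))  # wait 2s, then 4s, then 8s, ...
        return False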

 def load_fastsam_model():
     global fastsam_model
     if fastsam_model is None:
         if not check_and_import_fastsam():
+            print("Cannot load FastSAM model due to library import failure.")
+            return False
         if download_fastsam_weights():
             try:
                 print(f"Loading FastSAM model: {FASTSAM_CHECKPOINT}...")
                 fastsam_model = FastSAM(FASTSAM_CHECKPOINT)
+                print("FastSAM model loaded.")
+                return True
             except Exception as e:
                 print(f"Error loading FastSAM model: {e}")
                 traceback.print_exc()
+                return False
         else:
+            print("FastSAM weights not found or download failed.")
+            return False
+    return True
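Note: `load_fastsam_model()` now returns an explicit boolean on every path, including the failure branches inside the try/except, so callers can gate on the result directly:

    if load_fastsam_model():
        print("FastSAM ready.")
    else:
        print("FastSAM unavailable; the segmentation tabs will return error strings.")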

 # --- Processing Functions ---

 def run_clip_zero_shot(image: Image.Image, text_labels: str):
     if clip_model is None or clip_processor is None:
         if not load_clip_model():
+            return "Error: CLIP Model could not be loaded.", None
+    if image is None:
+        return "Please upload an image.", None
+    if not text_labels:
+        return {}, image

     labels = [label.strip() for label in text_labels.split(',') if label.strip()]
+    if not labels:
+        return {}, image

     print(f"Running CLIP zero-shot classification with labels: {labels}")
     try:
+        if image.mode != "RGB":
+            image = image.convert("RGB")
         inputs = clip_processor(text=labels, images=image, return_tensors="pt", padding=True).to(DEVICE)
         with torch.no_grad():
             outputs = clip_model(**inputs)
             probs = outputs.logits_per_image.softmax(dim=1)
         confidences = {labels[i]: float(probs[0, i].item()) for i in range(len(labels))}
         return confidences, image
     except Exception as e:
         print(f"Error during CLIP processing: {e}")
         traceback.print_exc()
+        return f"Error: {e}", image
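Note: `logits_per_image.softmax(dim=1)` normalizes over the supplied labels only, so the returned confidences always sum to 1 across that label set; they are relative preferences, not absolute detections. The same zero-shot scoring as a self-contained sketch outside Gradio (mirrors the handler above; the `clip_scores` name is illustrative):

    import torch
    from PIL import Image
    from transformers import AutoModel, AutoProcessor

    MODEL_ID = "openai/clip-vit-base-patch32"
    processor = AutoProcessor.from_pretrained(MODEL_ID)
    model = AutoModel.from_pretrained(MODEL_ID)

    def clip_scores(image: Image.Image, labels: list[str]) -> dict[str, float]:
        # Probabilities are a softmax over *these* labels and sum to 1.
        inputs = processor(text=labels, images=image, return_tensors="pt", padding=True)
        with torch.no_grad():
            probs = model(**inputs).logits_per_image.softmax(dim=1)[0]
        return dict(zip(labels, probs.tolist()))

    # e.g. clip_scores(Image.open("examples/astronaut.jpg"), ["astronaut", "moon", "dog"])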

 def run_fastsam_segmentation(image_pil: Image.Image, conf_threshold: float = 0.4, iou_threshold: float = 0.9):
+    if not load_fastsam_model() or not fastsam_lib_imported:
+        return "Error: FastSAM not loaded or library unavailable."
+    if image_pil is None:
+        return "Please upload an image."

     print("Running FastSAM 'segment everything'...")
     try:
+        if image_pil.mode != "RGB":
+            image_pil = image_pil.convert("RGB")
         image_np_rgb = np.array(image_pil)

         everything_results = fastsam_model(
             image_np_rgb, device=DEVICE, retina_masks=True, imgsz=640,
+            conf=conf_threshold, iou=iou_threshold, verbose=True
         )
         prompt_process = FastSAMPrompt(image_np_rgb, everything_results, device=DEVICE)
         ann = prompt_process.everything_prompt()

         output_image = image_pil.copy()
+        if ann and ann[0] and 'masks' in ann[0] and len(ann[0]['masks']) > 0:
             masks = ann[0]['masks'].cpu().numpy()
+            print(f"Found {len(masks)} masks with shape: {masks.shape}")
             overlay = Image.new('RGBA', output_image.size, (0, 0, 0, 0))
             draw = ImageDraw.Draw(overlay)
             for mask in masks:
+                mask = (mask > 0).astype(np.uint8) * 255
+                color = (random.randint(50, 255), random.randint(50, 255), random.randint(50, 255), 180)
+                mask_image = Image.fromarray(mask, mode='L')
                 draw.bitmap((0, 0), mask_image, fill=color)
             output_image = Image.alpha_composite(output_image.convert('RGBA'), overlay).convert('RGB')
+        else:
+            print("No masks detected in 'segment everything' mode.")
         return output_image
     except Exception as e:
         print(f"Error during FastSAM 'everything' processing: {e}")
         traceback.print_exc()
+        return f"Error: {e}"
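Note: the added binarization line is the substantive fix here. FastSAM masks can arrive as boolean or float arrays, and `(mask > 0).astype(np.uint8) * 255` hands `ImageDraw.bitmap` a hard 0/255 stencil instead of intermediate gray levels that would blend faintly. Both segmentation paths now share this overlay recipe; as a standalone sketch (the `overlay_masks` helper is illustrative, not part of this commit):

    import random
    import numpy as np
    from PIL import Image, ImageDraw

    def overlay_masks(image: Image.Image, masks: np.ndarray, alpha: int = 180) -> Image.Image:
        # Tint each (H, W) mask in a random color, then composite once over the image.
        overlay = Image.new("RGBA", image.size, (0, 0, 0, 0))
        draw = ImageDraw.Draw(overlay)
        for mask in masks:
            stencil = Image.fromarray((mask > 0).astype(np.uint8) * 255, mode="L")
            color = tuple(random.randint(50, 255) for _ in range(3)) + (alpha,)
            draw.bitmap((0, 0), stencil, fill=color)  # stamp color through the stencil
        return Image.alpha_composite(image.convert("RGBA"), overlay).convert("RGB")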
 
171
 
 
172
  def run_text_prompted_segmentation(image_pil: Image.Image, text_prompts: str, conf_threshold: float = 0.4, iou_threshold: float = 0.9):
 
173
  if not load_fastsam_model():
174
+ return "Error: FastSAM Model not loaded.", "Model load failure."
175
  if not fastsam_lib_imported:
176
+ return "Error: FastSAM library not available.", "Library import error."
177
  if image_pil is None:
178
  return "Please upload an image.", "No image provided."
179
  if not text_prompts:
180
+ return image_pil, "Please enter text prompts (e.g., 'person, dog')."
181
 
182
  prompts = [p.strip() for p in text_prompts.split(',') if p.strip()]
183
  if not prompts:
184
  return image_pil, "No valid text prompts entered."
185
 
186
  print(f"Running FastSAM text-prompted segmentation for: {prompts}")
 
187
  try:
188
  if image_pil.mode != "RGB":
189
  image_pil = image_pil.convert("RGB")
190
  image_np_rgb = np.array(image_pil)
191
 
 
 
192
  everything_results = fastsam_model(
193
  image_np_rgb, device=DEVICE, retina_masks=True, imgsz=640,
194
+ conf=conf_threshold, iou=iou_threshold, verbose=True
195
  )
 
 
196
  prompt_process = FastSAMPrompt(image_np_rgb, everything_results, device=DEVICE)
 
 
197
  all_matching_masks = []
198
  found_prompts = []
199
 
200
  for text in prompts:
201
  print(f" Processing prompt: '{text}'")
 
 
 
 
202
  ann = prompt_process.text_prompt(text=text)
203
+ if ann and ann[0] and 'masks' in ann[0] and len(ann[0]['masks']) > 0:
 
204
  num_found = len(ann[0]['masks'])
205
+ print(f" Found {num_found} mask(s) with shape: {ann[0]['masks'].shape}")
206
  found_prompts.append(f"{text} ({num_found})")
207
+ masks = ann[0]['masks'].cpu().numpy()
208
+ all_matching_masks.extend(masks)
209
  else:
210
+ print(f" No masks found for '{text}'.")
211
  found_prompts.append(f"{text} (0)")
212
 
 
213
  output_image = image_pil.copy()
214
+ status_message = f"Found segments for: {', '.join(found_prompts)}" if found_prompts else "No matches found."
 
 
 
 
 
 
 
 
 
 
215
 
216
+ if all_matching_masks:
217
+ masks_np = np.stack(all_matching_masks, axis=0)
218
+ print(f"Total masks stacked: {masks_np.shape}")
219
+ overlay = Image.new('RGBA', output_image.size, (0, 0, 0, 0))
220
+ draw = ImageDraw.Draw(overlay)
221
+ for mask in masks_np:
222
+ mask = (mask > 0).astype(np.uint8) * 255
223
+ color = (random.randint(50, 255), random.randint(50, 255), random.randint(50, 255), 180)
224
+ mask_image = Image.fromarray(mask, mode='L')
225
+ draw.bitmap((0, 0), mask_image, fill=color)
226
+ output_image = Image.alpha_composite(output_image.convert('RGBA'), overlay).convert('RGB')
227
 
 
228
  return output_image, status_message
 
229
  except Exception as e:
230
  print(f"Error during FastSAM text-prompted processing: {e}")
231
  traceback.print_exc()
232
+ return image_pil, f"Error: {e}"
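Note: the heavy `fastsam_model(...)` pass still runs once per request; as I read the fastsam reference implementation, each `text_prompt` call re-scores those cached candidate regions with CLIP rather than re-running segmentation, so extra prompts add ranking cost only. Each prompt contributes zero or more `(H, W)` masks, and `extend` plus `np.stack` rebuilds the `(N, H, W)` batch the overlay loop expects, e.g.:

    import numpy as np

    person_masks = np.ones((2, 4, 4))    # say "person" matched two regions
    dog_masks = np.ones((1, 4, 4))       # and "dog" matched one

    collected = []
    collected.extend(person_masks)       # extend() appends each (H, W) slice
    collected.extend(dog_masks)
    stacked = np.stack(collected, axis=0)
    assert stacked.shape == (3, 4, 4)    # one (N, H, W) batch again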
 

 # --- Gradio Interface ---

 print("Attempting to preload models...")
+load_fastsam_model() # Load FastSAM eagerly
+print("Preloading finished.")

 with gr.Blocks(theme=gr.themes.Soft()) as demo:
     gr.Markdown("# CLIP & FastSAM Demo")
     gr.Markdown("Explore Zero-Shot Classification, 'Segment Everything', and Text-Prompted Segmentation.")

     with gr.Tabs():
         with gr.TabItem("CLIP Zero-Shot Classification"):
+            gr.Markdown("Upload an image and provide comma-separated labels (e.g., 'cat, dog, car').")
             with gr.Row():
                 with gr.Column(scale=1):
                     clip_input_image = gr.Image(type="pil", label="Input Image")
+                    clip_text_labels = gr.Textbox(label="Comma-Separated Labels", placeholder="e.g., astronaut, moon")
                     clip_button = gr.Button("Run CLIP Classification", variant="primary")
                 with gr.Column(scale=1):
                     clip_output_label = gr.Label(label="Classification Probabilities")
@@ -298,70 +259,69 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
             )
             gr.Examples(
                 examples=[
+                    ["examples/astronaut.jpg", "astronaut, moon, rover"],
+                    ["examples/dog_bike.jpg", "dog, bicycle, person"],
+                    ["examples/clip_logo.png", "logo, text, graphics"],
                 ],
                 inputs=[clip_input_image, clip_text_labels],
+                outputs=[clip_output_label, clip_output_image_display],
+                fn=run_clip_zero_shot,
+                cache_examples=False,
             )

         with gr.TabItem("FastSAM Segment Everything"):
+            gr.Markdown("Upload an image to segment all objects/regions.")
+            with gr.Row():
+                with gr.Column(scale=1):
+                    fastsam_input_image_all = gr.Image(type="pil", label="Input Image")
+                    with gr.Row():
+                        fastsam_conf_all = gr.Slider(minimum=0.1, maximum=1.0, value=0.4, step=0.05, label="Confidence Threshold")
+                        fastsam_iou_all = gr.Slider(minimum=0.1, maximum=1.0, value=0.9, step=0.05, label="IoU Threshold")
+                    fastsam_button_all = gr.Button("Run FastSAM Segmentation", variant="primary")
+                with gr.Column(scale=1):
+                    fastsam_output_image_all = gr.Image(type="pil", label="Segmented Image")
+            fastsam_button_all.click(
+                run_fastsam_segmentation,
+                inputs=[fastsam_input_image_all, fastsam_conf_all, fastsam_iou_all],
+                outputs=[fastsam_output_image_all]
+            )
+            gr.Examples(
+                examples=[
+                    ["examples/dogs.jpg", 0.4, 0.9],
+                    ["examples/fruits.jpg", 0.5, 0.8],
+                    ["examples/lion.jpg", 0.45, 0.9],
+                ],
+                inputs=[fastsam_input_image_all, fastsam_conf_all, fastsam_iou_all],
+                outputs=[fastsam_output_image_all],
+                fn=run_fastsam_segmentation,
+                cache_examples=False,
+            )
+
         with gr.TabItem("Text-Prompted Segmentation"):
+            gr.Markdown("Upload an image and provide comma-separated prompts (e.g., 'person, dog').")
             with gr.Row():
                 with gr.Column(scale=1):
                     prompt_input_image = gr.Image(type="pil", label="Input Image")
+                    prompt_text_input = gr.Textbox(label="Comma-Separated Text Prompts", placeholder="e.g., glasses, watch")
+                    with gr.Row():
                         prompt_conf = gr.Slider(minimum=0.1, maximum=1.0, value=0.4, step=0.05, label="Confidence Threshold")
                         prompt_iou = gr.Slider(minimum=0.1, maximum=1.0, value=0.9, step=0.05, label="IoU Threshold")
                     prompt_button = gr.Button("Segment by Text", variant="primary")
                 with gr.Column(scale=1):
                     prompt_output_image = gr.Image(type="pil", label="Text-Prompted Segmentation")
+                    prompt_status_message = gr.Textbox(label="Status", interactive=False)

             prompt_button.click(
                 run_text_prompted_segmentation,
                 inputs=[prompt_input_image, prompt_text_input, prompt_conf, prompt_iou],
+                outputs=[prompt_output_image, prompt_status_message]
             )
             gr.Examples(
                 examples=[
                     ["examples/dog_bike.jpg", "person, bicycle", 0.4, 0.9],
                     ["examples/astronaut.jpg", "person, helmet", 0.35, 0.9],
+                    ["examples/dogs.jpg", "dog", 0.4, 0.9],
                     ["examples/fruits.jpg", "banana, apple", 0.5, 0.8],
+                    ["examples/teacher.jpg", "person, glasses", 0.4, 0.9],
                 ],
                 inputs=[prompt_input_image, prompt_text_input, prompt_conf, prompt_iou],
                 outputs=[prompt_output_image, prompt_status_message],
@@ -369,31 +329,33 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
                 cache_examples=False,
             )

+# Download example images with retries
 if not os.path.exists("examples"):
     os.makedirs("examples")
+    print("Created 'examples' directory.")
+example_files = {
+    "astronaut.jpg": "https://upload.wikimedia.org/wikipedia/commons/thumb/d/d1/Astronaut_-_St._Jean_Bay.jpg/640px-Astronaut_-_St._Jean_Bay.jpg",
+    "dog_bike.jpg": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/gradio/outputs_multimodal.jpg",
+    "clip_logo.png": "https://raw.githubusercontent.com/openai/CLIP/main/CLIP.png",
+    "dogs.jpg": "https://raw.githubusercontent.com/ultralytics/assets/main/im/image8.jpg",
+    "fruits.jpg": "https://raw.githubusercontent.com/ultralytics/assets/main/im/image9.jpg",
+    "lion.jpg": "https://huggingface.co/spaces/gradio/image-segmentation/resolve/main/images/lion.jpg",
+    "teacher.jpg": "https://images.pexels.com/photos/848117/pexels-photo-848117.jpeg?auto=compress&cs=tinysrgb&w=600"
+}
+def download_example_file(filename, url, retries=3):
+    filepath = os.path.join("examples", filename)
+    if not os.path.exists(filepath):
+        for attempt in range(retries):
+            try:
+                print(f"Downloading {filename} (attempt {attempt + 1}/{retries})...")
+                wget.download(url, filepath)
+                break
+            except Exception as e:
+                print(f"Attempt {attempt + 1} failed: {e}")
+                if attempt + 1 == retries:
+                    print(f"Failed to download {filename} after {retries} attempts.")
+for filename, url in example_files.items():
+    download_example_file(filename, url)
+
 if __name__ == "__main__":
+    demo.launch(debug=True)
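Note: the `wget` package's `download()` does not expose a timeout parameter, so a stalled connection can hang the weight and example downloads despite the retry loops. If that matters on Spaces, the standard library can stream with a hard timeout; an illustrative alternative (`download_with_timeout` is not part of this commit):

    import urllib.request

    def download_with_timeout(url: str, dest: str, timeout_s: float = 30.0) -> None:
        # Stream to disk with a connect/read timeout instead of hanging forever.
        with urllib.request.urlopen(url, timeout=timeout_s) as response, open(dest, "wb") as f:
            while True:
                chunk = response.read(64 * 1024)
                if not chunk:
                    break
                f.write(chunk)

    # e.g. download_with_timeout(FASTSAM_CHECKPOINT_URL, FASTSAM_CHECKPOINT)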