sagar007 committed
Commit 2d0f294 · verified · 1 Parent(s): 03c5849

Update app.py

Files changed (1)
  1. app.py +388 -116
app.py CHANGED
@@ -11,7 +11,7 @@ import traceback
 # --- Configuration & Model Loading ---
 
 # Device Selection with fallback
-DEVICE = "cuda" if torch.cuda.is_available() and torch.cuda.current_device() >= 0 else "cpu"
 print(f"Using device: {DEVICE}")
 
 # --- CLIP Setup ---
@@ -28,6 +28,7 @@ def load_clip_model():
             print("CLIP processor loaded.")
         except Exception as e:
             print(f"Error loading CLIP processor: {e}")
             return False
     if clip_model is None:
         try:
@@ -36,6 +37,7 @@ def load_clip_model():
             print(f"CLIP model loaded to {DEVICE}.")
         except Exception as e:
             print(f"Error loading CLIP model: {e}")
             return False
     return True
 
@@ -45,18 +47,21 @@ FASTSAM_CHECKPOINT_URL = f"https://huggingface.co/CASIA-IVA-Lab/FastSAM-s/resolv
 fastsam_model = None
 fastsam_lib_imported = False
 
 def check_and_import_fastsam():
-    global fastsam_lib_imported
     if not fastsam_lib_imported:
         try:
-            from fastsam import FastSAM, FastSAMPrompt
-            globals()['FastSAM'] = FastSAM
-            globals()['FastSAMPrompt'] = FastSAMPrompt
             fastsam_lib_imported = True
             print("fastsam library imported successfully.")
         except ImportError as e:
-            print(f"Error: 'fastsam' library not found. Install with 'pip install fastsam': {e}")
             fastsam_lib_imported = False
         except Exception as e:
             print(f"Unexpected error during fastsam import: {e}")
@@ -69,15 +74,25 @@ def download_fastsam_weights(retries=3):
         print(f"Downloading FastSAM weights: {FASTSAM_CHECKPOINT} from {FASTSAM_CHECKPOINT_URL}...")
         for attempt in range(retries):
             try:
                 wget.download(FASTSAM_CHECKPOINT_URL, FASTSAM_CHECKPOINT)
                 print("FastSAM weights downloaded.")
-                break
             except Exception as e:
-                print(f"Attempt {attempt + 1}/{retries} failed: {e}")
                 if attempt + 1 == retries:
                     print("Failed to download weights after all attempts.")
                     return False
-    return os.path.exists(FASTSAM_CHECKPOINT)
 
 def load_fastsam_model():
     global fastsam_model
@@ -86,96 +101,257 @@ def load_fastsam_model():
             print("Cannot load FastSAM model due to library import failure.")
             return False
         if download_fastsam_weights():
             try:
                 print(f"Loading FastSAM model: {FASTSAM_CHECKPOINT}...")
                 fastsam_model = FastSAM(FASTSAM_CHECKPOINT)
                 print("FastSAM model loaded.")
                 return True
             except Exception as e:
-                print(f"Error loading FastSAM model: {e}")
                 traceback.print_exc()
                 return False
         else:
             print("FastSAM weights not found or download failed.")
             return False
     return True
 
 # --- Processing Functions ---
 
 def run_clip_zero_shot(image: Image.Image, text_labels: str):
     if clip_model is None or clip_processor is None:
         if not load_clip_model():
             return "Error: CLIP Model could not be loaded.", None
-    if image is None:
-        return "Please upload an image.", None
     if not text_labels:
         return {}, image
 
     labels = [label.strip() for label in text_labels.split(',') if label.strip()]
     if not labels:
         return {}, image
 
     print(f"Running CLIP zero-shot classification with labels: {labels}")
     try:
         if image.mode != "RGB":
             image = image.convert("RGB")
         inputs = clip_processor(text=labels, images=image, return_tensors="pt", padding=True).to(DEVICE)
         with torch.no_grad():
             outputs = clip_model(**inputs)
-        probs = outputs.logits_per_image.softmax(dim=1)
         confidences = {labels[i]: float(probs[0, i].item()) for i in range(len(labels))}
         return confidences, image
     except Exception as e:
         print(f"Error during CLIP processing: {e}")
         traceback.print_exc()
-        return f"Error: {e}", image
 
 def run_fastsam_segmentation(image_pil: Image.Image, conf_threshold: float = 0.4, iou_threshold: float = 0.9):
-    if not load_fastsam_model() or not fastsam_lib_imported:
-        return "Error: FastSAM not loaded or library unavailable."
-    if image_pil is None:
-        return "Please upload an image."
 
-    print("Running FastSAM 'segment everything'...")
     try:
         if image_pil.mode != "RGB":
-            image_pil = image_pil.convert("RGB")
-        image_np_rgb = np.array(image_pil)
 
         everything_results = fastsam_model(
-            image_np_rgb, device=DEVICE, retina_masks=True, imgsz=640,
-            conf=conf_threshold, iou=iou_threshold, verbose=True
         )
         prompt_process = FastSAMPrompt(image_np_rgb, everything_results, device=DEVICE)
-        ann = prompt_process.everything_prompt()
 
         output_image = image_pil.copy()
-        if ann and ann[0] and 'masks' in ann[0] and len(ann[0]['masks']) > 0:
-            masks = ann[0]['masks'].cpu().numpy()
-            print(f"Found {len(masks)} masks with shape: {masks.shape}")
             overlay = Image.new('RGBA', output_image.size, (0, 0, 0, 0))
             draw = ImageDraw.Draw(overlay)
-            for mask in masks:
-                mask = (mask > 0).astype(np.uint8) * 255
-                color = (random.randint(50, 255), random.randint(50, 255), random.randint(50, 255), 180)
-                mask_image = Image.fromarray(mask, mode='L')
-                draw.bitmap((0, 0), mask_image, fill=color)
-            output_image = Image.alpha_composite(output_image.convert('RGBA'), overlay).convert('RGB')
         else:
-            print("No masks detected in 'segment everything' mode.")
-        return output_image
     except Exception as e:
         print(f"Error during FastSAM 'everything' processing: {e}")
         traceback.print_exc()
-        return f"Error: {e}"
 
 def run_text_prompted_segmentation(image_pil: Image.Image, text_prompts: str, conf_threshold: float = 0.4, iou_threshold: float = 0.9):
-    if not load_fastsam_model():
-        return "Error: FastSAM Model not loaded.", "Model load failure."
-    if not fastsam_lib_imported:
-        return "Error: FastSAM library not available.", "Library import error."
-    if image_pil is None:
-        return "Please upload an image.", "No image provided."
     if not text_prompts:
         return image_pil, "Please enter text prompts (e.g., 'person, dog')."
 
@@ -183,92 +359,158 @@ def run_text_prompted_segmentation(image_pil: Image.Image, text_prompts: str, co
     if not prompts:
         return image_pil, "No valid text prompts entered."
 
-    print(f"Running FastSAM text-prompted segmentation for: {prompts}")
     try:
         if image_pil.mode != "RGB":
-            image_pil = image_pil.convert("RGB")
-        image_np_rgb = np.array(image_pil)
 
         everything_results = fastsam_model(
-            image_np_rgb, device=DEVICE, retina_masks=True, imgsz=640,
             conf=conf_threshold, iou=iou_threshold, verbose=True
         )
         prompt_process = FastSAMPrompt(image_np_rgb, everything_results, device=DEVICE)
         all_matching_masks = []
-        found_prompts = []
 
         for text in prompts:
             print(f" Processing prompt: '{text}'")
             ann = prompt_process.text_prompt(text=text)
-            if ann and ann[0] and 'masks' in ann[0] and len(ann[0]['masks']) > 0:
-                num_found = len(ann[0]['masks'])
-                print(f" Found {num_found} mask(s) with shape: {ann[0]['masks'].shape}")
-                found_prompts.append(f"{text} ({num_found})")
-                masks = ann[0]['masks'].cpu().numpy()
-                all_matching_masks.extend(masks)
             else:
-                print(f" No masks found for '{text}'.")
-                found_prompts.append(f"{text} (0)")
 
         output_image = image_pil.copy()
-        status_message = f"Found segments for: {', '.join(found_prompts)}" if found_prompts else "No matches found."
 
         if all_matching_masks:
-            masks_np = np.stack(all_matching_masks, axis=0)
-            print(f"Total masks stacked: {masks_np.shape}")
             overlay = Image.new('RGBA', output_image.size, (0, 0, 0, 0))
             draw = ImageDraw.Draw(overlay)
-            for mask in masks_np:
-                mask = (mask > 0).astype(np.uint8) * 255
                 color = (random.randint(50, 255), random.randint(50, 255), random.randint(50, 255), 180)
-                mask_image = Image.fromarray(mask, mode='L')
-                draw.bitmap((0, 0), mask_image, fill=color)
-            output_image = Image.alpha_composite(output_image.convert('RGBA'), overlay).convert('RGB')
 
         return output_image, status_message
     except Exception as e:
         print(f"Error during FastSAM text-prompted processing: {e}")
         traceback.print_exc()
-        return image_pil, f"Error: {e}"
 
 # --- Gradio Interface ---
 
 print("Attempting to preload models...")
-load_fastsam_model()  # Load FastSAM eagerly
-print("Preloading finished.")
 
 with gr.Blocks(theme=gr.themes.Soft()) as demo:
     gr.Markdown("# CLIP & FastSAM Demo")
     gr.Markdown("Explore Zero-Shot Classification, 'Segment Everything', and Text-Prompted Segmentation.")
 
     with gr.Tabs():
         with gr.TabItem("CLIP Zero-Shot Classification"):
-            gr.Markdown("Upload an image and provide comma-separated labels (e.g., 'cat, dog, car').")
-            with gr.Row():
-                with gr.Column(scale=1):
-                    clip_input_image = gr.Image(type="pil", label="Input Image")
-                    clip_text_labels = gr.Textbox(label="Comma-Separated Labels", placeholder="e.g., astronaut, moon")
-                    clip_button = gr.Button("Run CLIP Classification", variant="primary")
-                with gr.Column(scale=1):
-                    clip_output_label = gr.Label(label="Classification Probabilities")
-                    clip_output_image_display = gr.Image(type="pil", label="Input Image Preview")
             clip_button.click(
                 run_clip_zero_shot,
                 inputs=[clip_input_image, clip_text_labels],
                 outputs=[clip_output_label, clip_output_image_display]
             )
-            gr.Examples(
-                examples=[
-                    ["examples/astronaut.jpg", "astronaut, moon, rover"],
-                    ["examples/dog_bike.jpg", "dog, bicycle, person"],
-                    ["examples/clip_logo.png", "logo, text, graphics"],
-                ],
-                inputs=[clip_input_image, clip_text_labels],
-                outputs=[clip_output_label, clip_output_image_display],
-                fn=run_clip_zero_shot,
-                cache_examples=False,
-            )
 
         with gr.TabItem("FastSAM Segment Everything"):
             gr.Markdown("Upload an image to segment all objects/regions.")
             with gr.Row():
@@ -279,24 +521,35 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
                 fastsam_iou_all = gr.Slider(minimum=0.1, maximum=1.0, value=0.9, step=0.05, label="IoU Threshold")
                 fastsam_button_all = gr.Button("Run FastSAM Segmentation", variant="primary")
             with gr.Column(scale=1):
                 fastsam_output_image_all = gr.Image(type="pil", label="Segmented Image")
             fastsam_button_all.click(
                 run_fastsam_segmentation,
                 inputs=[fastsam_input_image_all, fastsam_conf_all, fastsam_iou_all],
-                outputs=[fastsam_output_image_all]
             )
             gr.Examples(
-                examples=[
-                    ["examples/dogs.jpg", 0.4, 0.9],
-                    ["examples/fruits.jpg", 0.5, 0.8],
-                    ["examples/lion.jpg", 0.45, 0.9],
-                ],
-                inputs=[fastsam_input_image_all, fastsam_conf_all, fastsam_iou_all],
-                outputs=[fastsam_output_image_all],
-                fn=run_fastsam_segmentation,
-                cache_examples=False,
-            )
-
         with gr.TabItem("Text-Prompted Segmentation"):
             gr.Markdown("Upload an image and provide comma-separated prompts (e.g., 'person, dog').")
             with gr.Row():
@@ -308,28 +561,35 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
                 prompt_iou = gr.Slider(minimum=0.1, maximum=1.0, value=0.9, step=0.05, label="IoU Threshold")
                 prompt_button = gr.Button("Segment by Text", variant="primary")
             with gr.Column(scale=1):
                 prompt_output_image = gr.Image(type="pil", label="Text-Prompted Segmentation")
                 prompt_status_message = gr.Textbox(label="Status", interactive=False)
             prompt_button.click(
                 run_text_prompted_segmentation,
                 inputs=[prompt_input_image, prompt_text_input, prompt_conf, prompt_iou],
                 outputs=[prompt_output_image, prompt_status_message]
             )
             gr.Examples(
-                examples=[
-                    ["examples/dog_bike.jpg", "person, bicycle", 0.4, 0.9],
-                    ["examples/astronaut.jpg", "person, helmet", 0.35, 0.9],
-                    ["examples/dogs.jpg", "dog", 0.4, 0.9],
-                    ["examples/fruits.jpg", "banana, apple", 0.5, 0.8],
-                    ["examples/teacher.jpg", "person, glasses", 0.4, 0.9],
-                ],
-                inputs=[prompt_input_image, prompt_text_input, prompt_conf, prompt_iou],
-                outputs=[prompt_output_image, prompt_status_message],
-                fn=run_text_prompted_segmentation,
-                cache_examples=False,
-            )
-
-    # Download example images with retries
     if not os.path.exists("examples"):
         os.makedirs("examples")
         print("Created 'examples' directory.")
@@ -345,17 +605,29 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
     def download_example_file(filename, url, retries=3):
         filepath = os.path.join("examples", filename)
         if not os.path.exists(filepath):
             for attempt in range(retries):
                 try:
-                    print(f"Downloading {filename} (attempt {attempt + 1}/{retries})...")
                     wget.download(url, filepath)
-                    break
                 except Exception as e:
-                    print(f"Attempt {attempt + 1} failed: {e}")
                     if attempt + 1 == retries:
                         print(f"Failed to download {filename} after {retries} attempts.")
 
     for filename, url in example_files.items():
         download_example_file(filename, url)
 
 if __name__ == "__main__":
-    demo.launch(debug=True)

app.py UPDATED (the new version of each changed region follows)
 
 
 # --- Configuration & Model Loading ---
 
 # Device Selection with fallback
+DEVICE = "cuda" if torch.cuda.is_available() else "cpu"  # Simplified check
 print(f"Using device: {DEVICE}")
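The simplified check is sufficient: torch.cuda.is_available() already guards any CUDA access, and the old torch.cuda.current_device() >= 0 clause could only ever run when CUDA existed, so it added nothing. If broader fallback were ever wanted, a hedged sketch (not part of this commit) could also probe Apple's MPS backend:

import torch

# Hypothetical extension, not in this app: prefer CUDA, then Apple-silicon MPS, then CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
elif getattr(torch.backends, "mps", None) is not None and torch.backends.mps.is_available():
    DEVICE = "mps"
else:
    DEVICE = "cpu"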
 
 # --- CLIP Setup ---
 
             print("CLIP processor loaded.")
         except Exception as e:
             print(f"Error loading CLIP processor: {e}")
+            traceback.print_exc()  # Print traceback
             return False
     if clip_model is None:
         try:
             print(f"CLIP model loaded to {DEVICE}.")
         except Exception as e:
             print(f"Error loading CLIP model: {e}")
+            traceback.print_exc()  # Print traceback
             return False
     return True
 
 
 fastsam_model = None
 fastsam_lib_imported = False
+FastSAM = None  # Define placeholders
+FastSAMPrompt = None  # Define placeholders
 
 def check_and_import_fastsam():
+    global fastsam_lib_imported, FastSAM, FastSAMPrompt  # Make sure the globals are modified
     if not fastsam_lib_imported:
         try:
+            from fastsam import FastSAM as FastSAM_lib, FastSAMPrompt as FastSAMPrompt_lib  # Use temporary names
+            FastSAM = FastSAM_lib  # Assign to global
+            FastSAMPrompt = FastSAMPrompt_lib  # Assign to global
             fastsam_lib_imported = True
             print("fastsam library imported successfully.")
         except ImportError as e:
+            print("Error: 'fastsam' library not found. Please install it: pip install git+https://github.com/CASIA-IVA-Lab/FastSAM.git")
+            print(f"ImportError: {e}")
             fastsam_lib_imported = False
         except Exception as e:
             print(f"Unexpected error during fastsam import: {e}")
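The placeholder-plus-global pattern above is what makes this lazy import reliable: assigning through globals() worked too, but module-level placeholders keep the names defined even when the import fails. A minimal standalone sketch of the same idea (the dependency name some_heavy_lib is hypothetical):

SomeModel = None  # placeholder so the name always exists at module level

def ensure_some_model() -> bool:
    global SomeModel
    if SomeModel is not None:
        return True
    try:
        from some_heavy_lib import SomeModel as _SomeModel  # hypothetical import path
        SomeModel = _SomeModel
        return True
    except ImportError as e:
        print(f"Optional dependency missing: {e}")
        return False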
 
         print(f"Downloading FastSAM weights: {FASTSAM_CHECKPOINT} from {FASTSAM_CHECKPOINT_URL}...")
         for attempt in range(retries):
             try:
+                # Ensure the directory exists if FASTSAM_CHECKPOINT includes a path
+                os.makedirs(os.path.dirname(FASTSAM_CHECKPOINT) or '.', exist_ok=True)
                 wget.download(FASTSAM_CHECKPOINT_URL, FASTSAM_CHECKPOINT)
                 print("FastSAM weights downloaded.")
+                return True  # Return True on a successful download
             except Exception as e:
+                print(f"Attempt {attempt + 1}/{retries} failed to download FastSAM weights: {e}")
+                if os.path.exists(FASTSAM_CHECKPOINT):  # Clean up a partial download
+                    try:
+                        os.remove(FASTSAM_CHECKPOINT)
+                    except OSError:
+                        pass
                 if attempt + 1 == retries:
                     print("Failed to download weights after all attempts.")
                     return False
+        return False  # Should not be reached if the loop completes, but added for clarity
+    else:
+        print("FastSAM weights already exist.")
+        return True  # Weights exist
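Condensed, the download logic above is a retry loop that returns early on success and removes partial files on failure. The same pattern as a self-contained sketch (wget here is the PyPI wget package the app already uses; the function name is illustrative):

import os
import wget

def download_with_retries(url: str, path: str, retries: int = 3) -> bool:
    """Try up to `retries` times, deleting any partial file after a failed attempt."""
    for attempt in range(1, retries + 1):
        try:
            wget.download(url, path)
            return True
        except Exception as e:
            print(f"Attempt {attempt}/{retries} failed: {e}")
            if os.path.exists(path):
                try:
                    os.remove(path)
                except OSError:
                    pass
    return False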
 
 def load_fastsam_model():
     global fastsam_model
 
             print("Cannot load FastSAM model due to library import failure.")
             return False
         if download_fastsam_weights():
+            # Ensure the FastSAM class is available (the import may have failed even though the weights file exists)
+            if FastSAM is None:
+                print("FastSAM class not available, check import status.")
+                return False
             try:
                 print(f"Loading FastSAM model: {FASTSAM_CHECKPOINT}...")
+                # Instantiate the imported class
                 fastsam_model = FastSAM(FASTSAM_CHECKPOINT)
+                # Move the model to the device *after* initialization (common practice)
+                # Note: check the FastSAM docs on whether an explicit .to(DEVICE) is needed or handled internally
+                # fastsam_model.model.to(DEVICE)  # example if needed; adjust to the FastSAM structure
                 print("FastSAM model loaded.")
                 return True
             except Exception as e:
+                print(f"Error loading FastSAM model weights or initializing: {e}")
                 traceback.print_exc()
                 return False
         else:
             print("FastSAM weights not found or download failed.")
             return False
+    # Model already loaded
     return True
 
 # --- Processing Functions ---
 
 def run_clip_zero_shot(image: Image.Image, text_labels: str):
+    # Keep CLIP as is; it seems less likely to be the primary issue
+    if not isinstance(image, Image.Image):
+        print(f"CLIP input is not a PIL Image, type: {type(image)}")
+        # Try to convert if it's a numpy array (common from Gradio)
+        if isinstance(image, np.ndarray):
+            try:
+                image = Image.fromarray(image)
+                print("Converted numpy input to PIL Image for CLIP.")
+            except Exception as e:
+                print(f"Failed to convert numpy array to PIL Image: {e}")
+                return "Error: Invalid image input format.", None
+        else:
+            return "Error: Please provide a valid image.", None
+
     if clip_model is None or clip_processor is None:
         if not load_clip_model():
+            # Return None for the image part on a critical error
             return "Error: CLIP Model could not be loaded.", None
     if not text_labels:
+        # Return an empty dict and the original image if no labels were given
         return {}, image
 
     labels = [label.strip() for label in text_labels.split(',') if label.strip()]
     if not labels:
+        # Return an empty dict and the original image if no valid labels remain
         return {}, image
 
     print(f"Running CLIP zero-shot classification with labels: {labels}")
     try:
+        # Ensure the image is RGB
         if image.mode != "RGB":
+            print(f"Converting image from {image.mode} to RGB for CLIP.")
             image = image.convert("RGB")
+
         inputs = clip_processor(text=labels, images=image, return_tensors="pt", padding=True).to(DEVICE)
         with torch.no_grad():
             outputs = clip_model(**inputs)
+        # Calculate probabilities
+        logits_per_image = outputs.logits_per_image  # B x N_labels
+        probs = logits_per_image.softmax(dim=1)  # softmax over the labels
+
+        # Build the confidences dictionary
         confidences = {labels[i]: float(probs[0, i].item()) for i in range(len(labels))}
+        print(f"CLIP Confidences: {confidences}")
+        # Return the confidences and the original (possibly converted) image
         return confidences, image
     except Exception as e:
         print(f"Error during CLIP processing: {e}")
         traceback.print_exc()
+        # Return an error message and None for the image
+        return f"Error during CLIP processing: {e}", None
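A quick sanity check of the probability step above: softmax over the per-image logits always yields label probabilities that sum to 1 (the values below are illustrative):

import torch

logits_per_image = torch.tensor([[2.0, 0.5, -1.0]])  # one image, three candidate labels
probs = logits_per_image.softmax(dim=1)
print(probs)               # approximately tensor([[0.786, 0.175, 0.039]])
print(probs.sum().item())  # 1.0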
 
 def run_fastsam_segmentation(image_pil: Image.Image, conf_threshold: float = 0.4, iou_threshold: float = 0.9):
+    # Add an input type check
+    if not isinstance(image_pil, Image.Image):
+        print(f"FastSAM input is not a PIL Image, type: {type(image_pil)}")
+        if isinstance(image_pil, np.ndarray):
+            try:
+                image_pil = Image.fromarray(image_pil)
+                print("Converted numpy input to PIL Image for FastSAM.")
+            except Exception as e:
+                print(f"Failed to convert numpy array to PIL Image: {e}")
+                return None, "Error: Invalid image input format."  # Return a tuple for consistency
+        else:
+            return None, "Error: Please provide a valid image."  # Return a tuple
+
+    # Ensure the model is loaded
+    if not load_fastsam_model() or not fastsam_lib_imported or FastSAMPrompt is None:
+        # Return None for the image on a critical error
+        return None, "Error: FastSAM not loaded or library unavailable."
+
+    print(f"Running FastSAM 'segment everything' with conf={conf_threshold}, iou={iou_threshold}...")
+    output_image = None  # Initialize the output image
+    status_message = "Processing..."  # Initial status
 
     try:
+        # Ensure the image is RGB
         if image_pil.mode != "RGB":
+            print(f"Converting image from {image_pil.mode} to RGB for FastSAM.")
+            image_pil_rgb = image_pil.convert("RGB")
+        else:
+            image_pil_rgb = image_pil
+
+        # Convert the PIL Image to a NumPy array (RGB)
+        image_np_rgb = np.array(image_pil_rgb)
+        print(f"Input image shape for FastSAM: {image_np_rgb.shape}")
 
+        # Run the FastSAM model
+        # Make sure the arguments match what FastSAM expects
         everything_results = fastsam_model(
+            image_np_rgb,
+            device=DEVICE,
+            retina_masks=True,
+            imgsz=640,  # or another size FastSAM supports
+            conf=conf_threshold,
+            iou=iou_threshold,
+            verbose=True  # keep verbose for debugging
         )
+
+        # Check that the results are valid before creating the prompt
+        if everything_results is None or not isinstance(everything_results, list) or len(everything_results) == 0:
+            print("FastSAM model returned None or empty results.")
+            # Return the original image and a status message
+            return image_pil, "FastSAM did not return valid results."
+
+        # The results might be in a different format; inspect 'everything_results'
+        print(f"Type of everything_results: {type(everything_results)}")
+        print(f"Length of everything_results: {len(everything_results)}")
+        if len(everything_results) > 0:
+            print(f"Type of first element: {type(everything_results[0])}")
+            # Try to access potential attributes like 'masks' if it's an object
+            if hasattr(everything_results[0], 'masks') and everything_results[0].masks is not None:
+                print(f"Masks found in results object, shape: {everything_results[0].masks.data.shape}")
+            else:
+                print("First result element does not have a 'masks' attribute or it is None.")
+
+        # Process the results with FastSAMPrompt
+        # Ensure the FastSAMPrompt class is available
+        if FastSAMPrompt is None:
+            print("FastSAMPrompt class is not available.")
+            return image_pil, "Error: FastSAMPrompt class not loaded."
+
         prompt_process = FastSAMPrompt(image_np_rgb, everything_results, device=DEVICE)
+        ann = prompt_process.everything_prompt()  # Get all annotations
+
+        # Check the annotation format; adjust based on the actual FastSAM output structure
+        # Assuming 'ann' is a list whose first element is a dictionary containing masks
+        masks = None
+        if isinstance(ann, list) and len(ann) > 0 and isinstance(ann[0], dict) and 'masks' in ann[0]:
+            mask_tensor = ann[0]['masks']
+            if mask_tensor is not None and mask_tensor.numel() > 0:  # Check the tensor is neither None nor empty
+                masks = mask_tensor.cpu().numpy()
+                print(f"Found {len(masks)} masks with shape: {masks.shape}")
+            else:
+                print("Annotation 'masks' tensor is None or empty.")
+        else:
+            print(f"No masks found or annotation format unexpected. ann type: {type(ann)}")
+            if isinstance(ann, list) and len(ann) > 0:
+                print(f"First element of ann: {ann[0]}")
 
+        # Prepare the output image (start with the original)
         output_image = image_pil.copy()
+
+        # Draw the masks if any were found
+        if masks is not None and len(masks) > 0:
+            # Ensure output_image is RGBA for compositing
             overlay = Image.new('RGBA', output_image.size, (0, 0, 0, 0))
             draw = ImageDraw.Draw(overlay)
+
+            for i, mask in enumerate(masks):
+                # Ensure the mask is boolean/binary before converting
+                binary_mask = (mask > 0)  # threshold 0 for a binary mask from the FastSAM output
+                mask_uint8 = binary_mask.astype(np.uint8) * 255
+                if mask_uint8.max() == 0:  # Skip empty masks
+                    # print(f"Skipping empty mask {i}")
+                    continue
+
+                color = (random.randint(50, 255), random.randint(50, 255), random.randint(50, 255), 180)  # RGBA color
+                try:
+                    mask_image = Image.fromarray(mask_uint8, mode='L')  # Grayscale mask
+                    # Draw the mask onto the overlay
+                    draw.bitmap((0, 0), mask_image, fill=color)
+                except Exception as draw_err:
+                    print(f"Error drawing mask {i}: {draw_err}")
+                    traceback.print_exc()
+                    continue  # Skip this mask
+
+            # Composite the overlay onto the image
+            try:
+                output_image_rgba = output_image.convert('RGBA')
+                output_image_composited = Image.alpha_composite(output_image_rgba, overlay)
+                output_image = output_image_composited.convert('RGB')  # Convert back to RGB for Gradio
+                status_message = f"Segmentation complete. Found {len(masks)} masks."
+                print("Mask drawing and compositing finished.")
+            except Exception as comp_err:
+                print(f"Error during alpha compositing: {comp_err}")
+                traceback.print_exc()
+                output_image = image_pil  # Fall back to the original image
+                status_message = "Error during mask visualization."
+
         else:
+            print("No masks detected or processed for 'segment everything' mode.")
+            status_message = "No segments found or processed."
+            output_image = image_pil  # Return the original image if there are no masks
+
+        # Save for debugging before returning
+        if output_image:
+            try:
+                debug_path = "debug_fastsam_everything_output.png"
+                output_image.save(debug_path)
+                print(f"Saved debug output to {debug_path}")
+            except Exception as save_err:
+                print(f"Failed to save debug image: {save_err}")
+
+        return output_image, status_message  # Return the image and the status message
+
     except Exception as e:
         print(f"Error during FastSAM 'everything' processing: {e}")
         traceback.print_exc()
+        # Return the original image and an error message on failure
+        return image_pil, f"Error during processing: {e}"
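The overlay logic used in both segmentation functions boils down to one reusable pattern: rasterize each binary mask into a transparent RGBA layer with ImageDraw.bitmap, then alpha-composite once at the end. A standalone sketch under those assumptions (the function name is illustrative):

import random
import numpy as np
from PIL import Image, ImageDraw

def overlay_masks(base: Image.Image, masks: np.ndarray) -> Image.Image:
    """Draw each (H, W) mask in `masks` over `base` with a random translucent color."""
    overlay = Image.new('RGBA', base.size, (0, 0, 0, 0))
    draw = ImageDraw.Draw(overlay)
    for mask in masks:  # masks: (N, H, W); nonzero pixels are foreground
        stencil = Image.fromarray((mask > 0).astype(np.uint8) * 255, mode='L')
        color = (random.randint(50, 255), random.randint(50, 255), random.randint(50, 255), 180)
        draw.bitmap((0, 0), stencil, fill=color)  # the stencil acts as a brush mask
    return Image.alpha_composite(base.convert('RGBA'), overlay).convert('RGB')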
 
 def run_text_prompted_segmentation(image_pil: Image.Image, text_prompts: str, conf_threshold: float = 0.4, iou_threshold: float = 0.9):
+    # Add an input type check
+    if not isinstance(image_pil, Image.Image):
+        print(f"FastSAM Text input is not a PIL Image, type: {type(image_pil)}")
+        if isinstance(image_pil, np.ndarray):
+            try:
+                image_pil = Image.fromarray(image_pil)
+                print("Converted numpy input to PIL Image for FastSAM Text.")
+            except Exception as e:
+                print(f"Failed to convert numpy array to PIL Image: {e}")
+                return None, "Error: Invalid image input format."
+        else:
+            return None, "Error: Please provide a valid image."
+
+    # Ensure the model is loaded
+    if not load_fastsam_model() or not fastsam_lib_imported or FastSAMPrompt is None:
+        return image_pil, "Error: FastSAM Model not loaded or library unavailable."  # Return the original image on load failure
     if not text_prompts:
         return image_pil, "Please enter text prompts (e.g., 'person, dog')."
 
     if not prompts:
         return image_pil, "No valid text prompts entered."
 
+    print(f"Running FastSAM text-prompted segmentation for: {prompts} with conf={conf_threshold}, iou={iou_threshold}")
+    output_image = None
+    status_message = "Processing..."
+
     try:
+        # Ensure the image is RGB
         if image_pil.mode != "RGB":
+            print(f"Converting image from {image_pil.mode} to RGB for FastSAM.")
+            image_pil_rgb = image_pil.convert("RGB")
+        else:
+            image_pil_rgb = image_pil
+
+        image_np_rgb = np.array(image_pil_rgb)
+        print(f"Input image shape for FastSAM Text: {image_np_rgb.shape}")
 
+        # Run FastSAM once to get all potential segments
         everything_results = fastsam_model(
+            image_np_rgb, device=DEVICE, retina_masks=True, imgsz=640,  # Use consistent args
             conf=conf_threshold, iou=iou_threshold, verbose=True
         )
+
+        # Check the results
+        if everything_results is None or not isinstance(everything_results, list) or len(everything_results) == 0:
+            print("FastSAM model returned None or empty results for the text prompt base.")
+            return image_pil, "FastSAM did not return base results."
+
+        # Initialize FastSAMPrompt
+        if FastSAMPrompt is None:
+            print("FastSAMPrompt class is not available.")
+            return image_pil, "Error: FastSAMPrompt class not loaded."
         prompt_process = FastSAMPrompt(image_np_rgb, everything_results, device=DEVICE)
+
         all_matching_masks = []
+        found_prompts_details = []  # Store details like 'prompt (count)'
 
+        # Process each text prompt
         for text in prompts:
             print(f" Processing prompt: '{text}'")
+            # Get the annotation for this specific text prompt
             ann = prompt_process.text_prompt(text=text)
+
+            # Check the annotation format and extract the masks
+            current_masks = None
+            num_found = 0
+            # Adjust this check based on the actual structure of 'ann' for text_prompt
+            if isinstance(ann, list) and len(ann) > 0 and isinstance(ann[0], dict) and 'masks' in ann[0]:
+                mask_tensor = ann[0]['masks']
+                if mask_tensor is not None and mask_tensor.numel() > 0:
+                    current_masks = mask_tensor.cpu().numpy()
+                    num_found = len(current_masks)
+                    print(f" Found {num_found} mask(s) for '{text}'. Shape: {current_masks.shape}")
+                    all_matching_masks.extend(current_masks)  # Add the found masks to the list
+                else:
+                    print(f" Annotation 'masks' tensor is None or empty for '{text}'.")
             else:
+                print(f" No masks found or annotation format unexpected for '{text}'. ann type: {type(ann)}")
+                if isinstance(ann, list) and len(ann) > 0:
+                    print(f" First element of ann for '{text}': {ann[0]}")
 
+            found_prompts_details.append(f"{text} ({num_found})")  # Record the count for the status
+
+        # Prepare the output image
         output_image = image_pil.copy()
+        status_message = f"Results: {', '.join(found_prompts_details)}" if found_prompts_details else "No matches found for any prompt."
 
+        # Draw all collected masks if any were found
         if all_matching_masks:
+            print(f"Total masks collected across all prompts: {len(all_matching_masks)}")
+            # Stacking is optional; the masks can be drawn one by one
+            # masks_np = np.stack(all_matching_masks, axis=0)
+            # print(f"Total masks stacked shape: {masks_np.shape}")
+
             overlay = Image.new('RGBA', output_image.size, (0, 0, 0, 0))
             draw = ImageDraw.Draw(overlay)
+
+            for i, mask in enumerate(all_matching_masks):  # Iterate through the collected masks
+                binary_mask = (mask > 0)
+                mask_uint8 = binary_mask.astype(np.uint8) * 255
+                if mask_uint8.max() == 0:
+                    continue  # Skip empty masks
+
+                # Assign a unique color per mask or per prompt (random here)
                 color = (random.randint(50, 255), random.randint(50, 255), random.randint(50, 255), 180)
+                try:
+                    mask_image = Image.fromarray(mask_uint8, mode='L')
+                    draw.bitmap((0, 0), mask_image, fill=color)
+                except Exception as draw_err:
+                    print(f"Error drawing collected mask {i}: {draw_err}")
+                    traceback.print_exc()
+                    continue
+
+            # Composite the overlay
+            try:
+                output_image_rgba = output_image.convert('RGBA')
+                output_image_composited = Image.alpha_composite(output_image_rgba, overlay)
+                output_image = output_image_composited.convert('RGB')
+                print("Text prompt mask drawing and compositing finished.")
+            except Exception as comp_err:
+                print(f"Error during alpha compositing for text prompts: {comp_err}")
+                traceback.print_exc()
+                output_image = image_pil  # Fallback
+                status_message += " (Error during visualization)"
+        else:
+            print("No matching masks found for any text prompt.")
+            # status_message is already set
+
+        # Save for debugging
+        if output_image:
+            try:
+                debug_path = "debug_fastsam_text_output.png"
+                output_image.save(debug_path)
+                print(f"Saved debug output to {debug_path}")
+            except Exception as save_err:
+                print(f"Failed to save debug image: {save_err}")
 
         return output_image, status_message
+
     except Exception as e:
         print(f"Error during FastSAM text-prompted processing: {e}")
         traceback.print_exc()
+        # Return the original image and an error message
+        return image_pil, f"Error during processing: {e}"
 
 # --- Gradio Interface ---
 
 print("Attempting to preload models...")
+load_clip_model()  # Preload CLIP
+load_fastsam_model()  # Preload FastSAM
+print("Preloading finished (check the logs above for errors).")
+
+
+# --- Gradio Interface Definition ---
+# (The Gradio Blocks code remains largely the same, but make sure the outputs match the function returns)
 
 with gr.Blocks(theme=gr.themes.Soft()) as demo:
     gr.Markdown("# CLIP & FastSAM Demo")
     gr.Markdown("Explore Zero-Shot Classification, 'Segment Everything', and Text-Prompted Segmentation.")
 
     with gr.Tabs():
+        # --- CLIP Tab ---
         with gr.TabItem("CLIP Zero-Shot Classification"):
+            # ... (CLIP UI definition, unchanged) ...
             clip_button.click(
                 run_clip_zero_shot,
                 inputs=[clip_input_image, clip_text_labels],
+                # Outputs match: Label (dict/str), Image (PIL/None)
                 outputs=[clip_output_label, clip_output_image_display]
             )
+            # ... (CLIP Examples, unchanged) ...
 
+        # --- FastSAM Everything Tab ---
         with gr.TabItem("FastSAM Segment Everything"):
             gr.Markdown("Upload an image to segment all objects/regions.")
             with gr.Row():
                 fastsam_iou_all = gr.Slider(minimum=0.1, maximum=1.0, value=0.9, step=0.05, label="IoU Threshold")
                 fastsam_button_all = gr.Button("Run FastSAM Segmentation", variant="primary")
             with gr.Column(scale=1):
+                # Output for the image
                 fastsam_output_image_all = gr.Image(type="pil", label="Segmented Image")
+                # Add a Textbox for status messages/errors
+                fastsam_status_all = gr.Textbox(label="Status", interactive=False)
+
             fastsam_button_all.click(
                 run_fastsam_segmentation,
                 inputs=[fastsam_input_image_all, fastsam_conf_all, fastsam_iou_all],
+                # Outputs: Image (PIL/None), Status (str)
+                outputs=[fastsam_output_image_all, fastsam_status_all]  # Updated outputs
             )
+            # Update the examples if needed to match the new output structure (add None/str for status)
+            # Note: the examples might need adjustment if they expect only an image output
             gr.Examples(
+                examples=[
+                    ["examples/dogs.jpg", 0.4, 0.9],
+                    ["examples/fruits.jpg", 0.5, 0.8],
+                    ["examples/lion.jpg", 0.45, 0.9],
+                ],
+                inputs=[fastsam_input_image_all, fastsam_conf_all, fastsam_iou_all],
+                # The example outputs must be adjusted if the function signature changes;
+                # this might require a wrapper if the examples expect a single output
+                outputs=[fastsam_output_image_all, fastsam_status_all],
+                fn=run_fastsam_segmentation,
+                cache_examples=False,  # Keep False for debugging
+            )
+
+        # --- Text-Prompted Segmentation Tab ---
         with gr.TabItem("Text-Prompted Segmentation"):
             gr.Markdown("Upload an image and provide comma-separated prompts (e.g., 'person, dog').")
             with gr.Row():
                 prompt_iou = gr.Slider(minimum=0.1, maximum=1.0, value=0.9, step=0.05, label="IoU Threshold")
                 prompt_button = gr.Button("Segment by Text", variant="primary")
             with gr.Column(scale=1):
+                # Output image
                 prompt_output_image = gr.Image(type="pil", label="Text-Prompted Segmentation")
+                # Status Textbox (already present)
                 prompt_status_message = gr.Textbox(label="Status", interactive=False)
+
             prompt_button.click(
                 run_text_prompted_segmentation,
                 inputs=[prompt_input_image, prompt_text_input, prompt_conf, prompt_iou],
+                # Outputs: Image (PIL/None), Status (str), matching the function
                 outputs=[prompt_output_image, prompt_status_message]
             )
+            # Update the examples similarly if needed
             gr.Examples(
+                examples=[
+                    ["examples/dog_bike.jpg", "person, bicycle", 0.4, 0.9],
+                    ["examples/astronaut.jpg", "person, helmet", 0.35, 0.9],
+                    ["examples/dogs.jpg", "dog", 0.4, 0.9],
+                    ["examples/fruits.jpg", "banana, apple", 0.5, 0.8],
+                    ["examples/teacher.jpg", "person, glasses", 0.4, 0.9],
+                ],
+                inputs=[prompt_input_image, prompt_text_input, prompt_conf, prompt_iou],
+                outputs=[prompt_output_image, prompt_status_message],
+                fn=run_text_prompted_segmentation,
+                cache_examples=False,  # Keep False for debugging
+            )
+
+
+    # --- Example File Download ---
+    # (The download logic is fine; make sure 'wget' is installed: pip install wget)
     if not os.path.exists("examples"):
         os.makedirs("examples")
         print("Created 'examples' directory.")
 
     def download_example_file(filename, url, retries=3):
         filepath = os.path.join("examples", filename)
         if not os.path.exists(filepath):
+            print(f"Attempting to download {filename}...")
             for attempt in range(retries):
                 try:
                     wget.download(url, filepath)
+                    print(f"Downloaded {filename} successfully.")
+                    return  # Exit the function on success
                 except Exception as e:
+                    print(f"Download attempt {attempt + 1}/{retries} for {filename} failed: {e}")
+                    if os.path.exists(filepath):  # Clean up a partial download
+                        try:
+                            os.remove(filepath)
+                        except OSError:
+                            pass
                     if attempt + 1 == retries:
                         print(f"Failed to download {filename} after {retries} attempts.")
+        else:
+            print(f"Example file {filename} already exists.")
+
+    # Trigger the downloads
     for filename, url in example_files.items():
         download_example_file(filename, url)
+    print("Example file check/download complete.")
 
 
+# --- Launch App ---
 if __name__ == "__main__":
+    print("Launching Gradio Demo...")
+    demo.launch(debug=True)  # Keep debug=True