File size: 32,105 Bytes
dfdcd97
a3ee867
03c5849
b066832
fd55cab
b066832
eefe5b4
03c5849
 
0747bb5
b066832
 
 
03c5849
0747bb5
b066832
 
03c5849
b066832
 
 
 
 
 
 
23fa119
 
 
 
 
 
0747bb5
03c5849
b066832
23fa119
 
 
 
 
 
0747bb5
03c5849
 
b066832
 
 
eba2946
b066832
 
03c5849
2d0f294
 
eba2946
 
0747bb5
eba2946
0747bb5
 
 
 
 
 
 
 
 
 
 
 
 
 
eba2946
0747bb5
 
 
 
eba2946
 
03c5849
0747bb5
 
 
 
 
 
 
eba2946
 
03c5849
23fa119
eba2946
 
 
03c5849
b066832
eba2946
0747bb5
 
 
 
 
 
 
 
 
 
03c5849
 
 
0747bb5
2d0f294
03c5849
2d0f294
 
 
 
 
 
03c5849
 
 
0747bb5
2d0f294
0747bb5
2d0f294
b066832
 
 
 
0747bb5
23fa119
03c5849
 
0747bb5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
03c5849
2d0f294
0747bb5
03c5849
3cd1243
b066832
6facde6
b066832
0747bb5
 
 
2d0f294
0747bb5
2d0f294
 
 
 
 
 
 
 
 
 
0747bb5
b066832
23fa119
03c5849
0747bb5
 
03c5849
0747bb5
6facde6
23fa119
03c5849
0747bb5
6facde6
b066832
 
2d0f294
03c5849
2d0f294
03c5849
2d0f294
b066832
 
 
0747bb5
 
2d0f294
b066832
2d0f294
eba2946
0747bb5
6facde6
b066832
eba2946
2d0f294
 
6facde6
b066832
0747bb5
 
 
2d0f294
0747bb5
2d0f294
 
 
 
 
 
0747bb5
2d0f294
0747bb5
2d0f294
0747bb5
2d0f294
0747bb5
2d0f294
 
0747bb5
 
23fa119
 
2d0f294
03c5849
2d0f294
 
 
 
 
 
 
eba2946
2d0f294
23fa119
0747bb5
 
23fa119
2d0f294
0747bb5
 
2d0f294
0747bb5
 
 
 
 
2d0f294
0747bb5
 
 
 
 
2d0f294
0747bb5
2d0f294
 
 
 
23fa119
2d0f294
 
0747bb5
2d0f294
0747bb5
2d0f294
 
0747bb5
2d0f294
 
 
0747bb5
2d0f294
 
0747bb5
2d0f294
0747bb5
23fa119
2d0f294
 
 
23fa119
 
0747bb5
2d0f294
0747bb5
2d0f294
0747bb5
2d0f294
0747bb5
2d0f294
0747bb5
2d0f294
0747bb5
2d0f294
 
 
 
0747bb5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
03c5849
2d0f294
 
0747bb5
2d0f294
 
 
 
0747bb5
2d0f294
 
 
0747bb5
2d0f294
23fa119
 
 
0747bb5
2d0f294
23fa119
 
0747bb5
 
 
2d0f294
0747bb5
2d0f294
 
 
 
 
 
 
 
 
 
0747bb5
2d0f294
0747bb5
23fa119
03c5849
e0d4d2f
23fa119
 
 
 
2d0f294
 
 
 
3d6a9c7
2d0f294
b066832
2d0f294
 
 
 
 
 
 
72f4c5c
2d0f294
b066832
0747bb5
 
b066832
2d0f294
 
 
0747bb5
2d0f294
 
 
 
 
b066832
2d0f294
23fa119
0747bb5
23fa119
2d0f294
23fa119
 
 
2d0f294
 
 
0747bb5
2d0f294
 
0747bb5
2d0f294
 
 
0747bb5
2d0f294
0747bb5
23fa119
2d0f294
0747bb5
23fa119
0747bb5
2d0f294
 
b066832
2d0f294
6facde6
2d0f294
03c5849
2d0f294
03c5849
 
0747bb5
2d0f294
0747bb5
2d0f294
 
0747bb5
2d0f294
03c5849
2d0f294
 
 
0747bb5
2d0f294
 
 
 
0747bb5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2d0f294
 
0747bb5
2d0f294
 
 
 
0747bb5
2d0f294
 
23fa119
 
2d0f294
eefe5b4
23fa119
eba2946
2d0f294
6facde6
0747bb5
b066832
0747bb5
 
 
2d0f294
 
 
0747bb5
b066832
0747bb5
 
 
 
 
b066832
 
0747bb5
b066832
0747bb5
599a500
 
0747bb5
599a500
0747bb5
599a500
 
 
0747bb5
599a500
0747bb5
b066832
 
 
 
 
6facde6
0747bb5
 
 
 
 
 
 
 
 
 
 
 
 
23fa119
0747bb5
 
 
 
 
 
 
 
 
 
 
 
2d0f294
0747bb5
03c5849
 
0747bb5
22401e9
03c5849
599a500
0747bb5
 
 
 
 
 
 
 
 
 
 
 
 
23fa119
0747bb5
 
 
 
 
 
 
 
 
 
 
 
 
2d0f294
0747bb5
23fa119
 
0747bb5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b066832
2d0f294
0747bb5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
03c5849
0747bb5
b066832
0747bb5
2d0f294
0747bb5
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
import gradio as gr
import torch
from transformers import AutoProcessor, AutoModel
from PIL import Image, ImageDraw, ImageFont
import numpy as np
import random
import os
import wget
import traceback
import sys # Import sys for checking modules

# --- Configuration & Model Loading ---

# Pick the compute device once at import time: use the GPU when PyTorch can
# see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
print(f"Using device: {DEVICE}")

# --- CLIP Setup ---
# Hugging Face hub id shared by the CLIP processor and the CLIP model.
CLIP_MODEL_ID = "openai/clip-vit-base-patch32"
# Both start out unset; load_clip_model() populates them on demand.
clip_processor = clip_model = None

def load_clip_model():
    """Make sure the CLIP processor and model globals are populated.

    Each of the two objects is loaded only if its global is still ``None``,
    so a call after a partial failure retries just the missing piece.

    Returns:
        ``True`` when both ``clip_processor`` and ``clip_model`` are available
        after the call, ``False`` as soon as either load raises (the error and
        its traceback are printed, not propagated).
    """
    global clip_processor, clip_model

    if clip_processor is None:
        print(f"Loading CLIP processor: {CLIP_MODEL_ID}...")
        try:
            clip_processor = AutoProcessor.from_pretrained(CLIP_MODEL_ID)
        except Exception as processor_error:
            print(f"Error loading CLIP processor: {processor_error}")
            traceback.print_exc()
            return False
        print("CLIP processor loaded.")

    if clip_model is None:
        print(f"Loading CLIP model: {CLIP_MODEL_ID}...")
        try:
            # Assign only after the move to DEVICE succeeded, so a failed
            # transfer leaves the global unset and the next call retries.
            pretrained = AutoModel.from_pretrained(CLIP_MODEL_ID)
            clip_model = pretrained.to(DEVICE)
        except Exception as model_error:
            print(f"Error loading CLIP model: {model_error}")
            traceback.print_exc()
            return False
        print(f"CLIP model loaded to {DEVICE}.")

    return True

# --- FastSAM Setup ---
# Local file name of the weights; it is also the final path component of the
# download URL.
FASTSAM_CHECKPOINT = "FastSAM-s.pt"
FASTSAM_CHECKPOINT_URL = (
    "https://huggingface.co/CASIA-IVA-Lab/FastSAM-s/resolve/main/" + FASTSAM_CHECKPOINT
)

# Lazily-filled module state: the import flag and the two class placeholders
# are set by check_and_import_fastsam(), the model by load_fastsam_model().
fastsam_lib_imported = False
fastsam_model = None
FastSAM = None
FastSAMPrompt = None

def check_and_import_fastsam():
    """Import the optional ``fastsam`` package on first use and remember the outcome.

    On success the module-level ``FastSAM`` and ``FastSAMPrompt`` placeholders
    are bound to the imported classes and ``fastsam_lib_imported`` is set, so
    later calls return immediately.  Failures are reported on stdout with
    installation hints instead of being raised.

    Returns:
        The value of ``fastsam_lib_imported`` after the attempt; ``False`` if
        either ``ultralytics`` or ``fastsam`` could not be imported.
    """
    global fastsam_lib_imported, FastSAM, FastSAMPrompt

    # Already imported by an earlier call: nothing to do.
    if fastsam_lib_imported:
        return fastsam_lib_imported

    # Probe for the ultralytics dependency first so the user gets a specific
    # hint; skipped when something else has already imported it.
    if 'ultralytics' not in sys.modules:
        try:
            import ultralytics
        except ImportError:
            print("\n".join((
                "\n--- ERROR ---",
                "The 'ultralytics' library (required by FastSAM) is not installed.",
                "Please install it first: pip install ultralytics",
                "---------------\n",
            )))
            return False
        print("Found 'ultralytics' library.")

    try:
        # Import under throwaway names, then publish through the globals.
        from fastsam import FastSAM as imported_model_cls, FastSAMPrompt as imported_prompt_cls
    except ImportError as import_error:
        print("\n".join((
            "\n--- ERROR ---",
            "The 'fastsam' library was not found or could not be imported.",
            "Please ensure it is installed correctly:",
            "  pip install git+https://github.com/CASIA-IVA-Lab/FastSAM.git",
            f"(ImportError: {import_error})",
            "Also ensure 'ultralytics' is installed: pip install ultralytics",
            "---------------\n",
        )))
        fastsam_lib_imported = False
    except Exception as unexpected_error:
        print(f"Unexpected error during fastsam import: {unexpected_error}")
        traceback.print_exc()
        fastsam_lib_imported = False
    else:
        FastSAM = imported_model_cls
        FastSAMPrompt = imported_prompt_cls
        fastsam_lib_imported = True
        print("fastsam library imported successfully.")

    return fastsam_lib_imported

def download_fastsam_weights(retries=3):
    """Ensure the FastSAM checkpoint file exists locally, downloading it if needed.

    Args:
        retries: Number of download attempts before giving up.

    Returns:
        ``True`` if ``FASTSAM_CHECKPOINT`` already exists or was downloaded,
        ``False`` if its directory could not be created or every attempt
        failed.
    """
    # Fast path: the weights are already on disk.
    if os.path.exists(FASTSAM_CHECKPOINT):
        print(f"FastSAM weights file '{FASTSAM_CHECKPOINT}' already exists.")
        return True

    print(f"Downloading FastSAM weights: {FASTSAM_CHECKPOINT} from {FASTSAM_CHECKPOINT_URL}...")

    # Only relevant when FASTSAM_CHECKPOINT carries a directory component.
    target_dir = os.path.dirname(FASTSAM_CHECKPOINT)
    if target_dir and not os.path.exists(target_dir):
        try:
            os.makedirs(target_dir)
            print(f"Created directory for weights: {target_dir}")
        except OSError as mkdir_error:
            print(f"Error creating directory {target_dir}: {mkdir_error}")
            return False

    for attempt_no in range(1, retries + 1):
        try:
            wget.download(FASTSAM_CHECKPOINT_URL, FASTSAM_CHECKPOINT)
            print("FastSAM weights downloaded successfully.")
            return True
        except Exception as download_error:
            print(f"Attempt {attempt_no}/{retries} failed to download FastSAM weights: {download_error}")
            # Remove a partial file so the exists() fast path above cannot
            # mistake it for a complete checkpoint later on.
            if os.path.exists(FASTSAM_CHECKPOINT):
                try:
                    os.remove(FASTSAM_CHECKPOINT)
                except OSError:
                    pass
            if attempt_no == retries:
                print("Failed to download weights after all attempts.")
                return False

    # Reached only when no attempt was made at all (retries < 1).
    return False

def load_fastsam_model():
    """Instantiate the FastSAM model once and cache it in ``fastsam_model``.

    The call imports the library, makes sure the checkpoint is on disk and
    then builds the model; every failing step is logged and reported through
    the return value rather than raised.

    Returns:
        ``True`` if the model is (already) loaded, ``False`` otherwise.
    """
    global fastsam_model

    # Reuse the cached instance on every call after the first success.
    if fastsam_model is not None:
        return True

    print("Attempting to load FastSAM model...")

    if not check_and_import_fastsam():
        print("Cannot load FastSAM model due to library import failure.")
        return False

    if not download_fastsam_weights():
        print("Cannot load FastSAM model because weights are missing or download failed.")
        return False

    # Defensive: the import reported success but the class placeholder was
    # never bound.
    if FastSAM is None:
        print("FastSAM class reference is None, cannot instantiate model.")
        return False

    print(f"Loading FastSAM model from checkpoint: {FASTSAM_CHECKPOINT}...")
    try:
        # No explicit device move here; callers pass device= at inference time.
        fastsam_model = FastSAM(FASTSAM_CHECKPOINT)
        print("FastSAM model loaded successfully.")
    except Exception as load_error:
        print(f"Error loading FastSAM model weights or initializing: {load_error}")
        traceback.print_exc()
        fastsam_model = None  # keep the "not loaded" state so a later call retries
        return False

    return True

# --- Processing Functions ---

def run_clip_zero_shot(image: Image.Image, text_labels: str):
    """Score an image against user-supplied labels with CLIP zero-shot classification.

    Args:
        image: PIL image from the Gradio component.  A numpy array is also
            accepted and converted; anything else is rejected.
        text_labels: Comma-separated candidate labels.  Blank entries are
            dropped and repeated labels are collapsed to their first
            occurrence.

    Returns:
        A ``(result, image)`` pair:

        * ``({label: probability}, rgb_image)`` on success,
        * ``({}, image)`` when no usable label was given,
        * ``(error_message, None)`` when the input is invalid, the model
          cannot be loaded, or inference raises.
    """
    # Input validation
    if image is None:
        return "Error: Please upload an image.", None
    if not isinstance(image, Image.Image):
        print(f"CLIP input is not a PIL Image, type: {type(image)}. Attempting conversion.")
        if not isinstance(image, np.ndarray):
            return "Error: Please provide a valid image.", None
        try:
            image = Image.fromarray(image)
            print("Converted numpy input to PIL Image for CLIP.")
        except Exception as e:
            print(f"Failed to convert numpy array to PIL Image: {e}")
            return "Error: Invalid image input format.", None

    # Load lazily in case the eager preload failed or never ran.
    if clip_model is None or clip_processor is None:
        if not load_clip_model():
            return "Error: CLIP Model could not be loaded.", None

    if not text_labels:
        return {}, image

    # De-duplicate while keeping the user's order.  A repeated label would be
    # scored as a separate class (splitting the softmax mass) and then collapse
    # into a single dict key, so the displayed probabilities would not sum to 1.
    stripped = (part.strip() for part in text_labels.split(','))
    labels = list(dict.fromkeys(label for label in stripped if label))
    if not labels:
        return {}, image

    print(f"Running CLIP zero-shot classification with labels: {labels}")
    try:
        if image.mode != "RGB":
            print(f"Converting image from {image.mode} to RGB for CLIP.")
            image = image.convert("RGB")

        inputs = clip_processor(text=labels, images=image, return_tensors="pt", padding=True).to(DEVICE)
        with torch.no_grad():
            outputs = clip_model(**inputs)
            # One row per image, one column per label; normalise across labels.
            probs = outputs.logits_per_image.softmax(dim=1)

        confidences = dict(zip(labels, probs[0].tolist()))
        print(f"CLIP Confidences: {confidences}")
        return confidences, image

    except Exception as e:
        # UI boundary: report the failure in the label widget instead of crashing.
        print(f"Error during CLIP processing: {e}")
        traceback.print_exc()
        return f"Error during CLIP processing: {e}", None


def _fastsam_coerce_to_pil(image, label):
    """Return ``(pil_image, None)`` or ``(None, error_message)`` for a UI image input.

    PIL images pass through untouched, numpy arrays are converted with
    ``Image.fromarray`` and every other type is rejected.  ``label`` only tags
    the log lines so the two callers can be told apart.
    """
    if isinstance(image, Image.Image):
        return image, None

    print(f"{label} input is not a PIL Image, type: {type(image)}. Attempting conversion.")
    if not isinstance(image, np.ndarray):
        return None, "Error: Please provide a valid image."
    try:
        converted = Image.fromarray(image)
    except Exception as e:
        print(f"Failed to convert numpy array to PIL Image: {e}")
        return None, "Error: Invalid image input format."
    print(f"Converted numpy input to PIL Image for {label}.")
    return converted, None


def _fastsam_segment_all(image_pil, conf_threshold, iou_threshold, label):
    """Run the FastSAM model once and return ``(rgb_array, results)``."""
    if image_pil.mode != "RGB":
        print(f"Converting image from {image_pil.mode} to RGB for FastSAM.")
        image_pil = image_pil.convert("RGB")

    image_np_rgb = np.array(image_pil)
    print(f"Input image shape for {label}: {image_np_rgb.shape}")

    results = fastsam_model(
        image_np_rgb, device=DEVICE, retina_masks=True, imgsz=640,
        conf=conf_threshold, iou=iou_threshold, verbose=False,
    )
    return image_np_rgb, results


def _fastsam_masks_from_annotation(ann):
    """Normalise a FastSAMPrompt result into a numpy stack of masks, or ``None``.

    Accepted shapes of ``ann``:

    * a ``torch.Tensor`` or ``numpy.ndarray`` of masks,
    * a non-empty list/tuple whose first element is a dict with a ``'masks'``
      entry holding one of the above (the only format handled previously).

    NOTE(review): the upstream FastSAM ``prompt.py`` appears to return a bare
    tensor from ``everything_prompt()`` and a numpy array from
    ``text_prompt()`` -- confirm against the installed version.

    A single 2-D mask is promoted to a one-element stack.  Empty or
    unrecognised input yields ``None``.
    """
    if isinstance(ann, (list, tuple)):
        if not ann or not isinstance(ann[0], dict) or 'masks' not in ann[0]:
            return None
        ann = ann[0]['masks']

    if isinstance(ann, torch.Tensor):
        if ann.numel() == 0:
            return None
        masks = ann.detach().cpu().numpy()
    elif isinstance(ann, np.ndarray):
        if ann.size == 0:
            return None
        masks = ann
    else:
        return None

    if masks.ndim == 2:
        masks = masks[np.newaxis, ...]
    return masks


def _fastsam_overlay_masks(base_image, masks):
    """Paint every non-empty mask onto ``base_image`` in a random translucent colour.

    Returns:
        ``(image, drawn, ok)`` where ``drawn`` is the number of masks painted
        and ``ok`` is ``False`` only when the final alpha compositing failed.
        ``image`` is the composited RGB picture on success and the untouched
        ``base_image`` when nothing was drawn or compositing failed.
    """
    overlay = Image.new('RGBA', base_image.size, (0, 0, 0, 0))
    draw = ImageDraw.Draw(overlay)
    drawn = 0

    for index, mask in enumerate(masks):
        mask_uint8 = (np.asarray(mask) > 0).astype(np.uint8) * 255
        if mask_uint8.max() == 0:
            continue  # nothing to paint for an all-background mask

        color = (random.randint(50, 255), random.randint(50, 255), random.randint(50, 255), 180)
        try:
            # A 2-D uint8 array becomes a mode "L" image, used as the stencil.
            draw.bitmap((0, 0), Image.fromarray(mask_uint8), fill=color)
            drawn += 1
        except Exception as draw_err:
            # One bad mask must not abort the rest of the visualisation.
            print(f"Error drawing mask {index}: {draw_err}")
            traceback.print_exc()

    if drawn == 0:
        return base_image, 0, True

    try:
        composited = Image.alpha_composite(base_image.convert('RGBA'), overlay).convert('RGB')
    except Exception as comp_err:
        print(f"Error during alpha compositing: {comp_err}")
        traceback.print_exc()
        return base_image, drawn, False

    print("Mask drawing and compositing finished.")
    return composited, drawn, True


def _fastsam_save_debug_image(image, path):
    """Best-effort dump of the result image to ``path``; failures are only logged."""
    try:
        image.save(path)
    except Exception as save_err:
        print(f"Failed to save debug image: {save_err}")


def run_fastsam_segmentation(image_pil: Image.Image, conf_threshold: float = 0.4, iou_threshold: float = 0.9):
    """Segment everything in an image with FastSAM and overlay the masks.

    Args:
        image_pil: PIL image from the Gradio component (numpy arrays are
            converted).
        conf_threshold: Confidence threshold forwarded to the model as ``conf``.
        iou_threshold: IoU threshold forwarded to the model as ``iou``.

    Returns:
        ``(image, status_message)``.  ``image`` is the RGB overlay on success,
        the original input when nothing could be drawn or processing failed,
        and ``None`` when the input itself was missing or invalid.
    """
    # Input validation
    if image_pil is None:
        return None, "Error: Please upload an image."
    image_pil, input_error = _fastsam_coerce_to_pil(image_pil, "FastSAM")
    if input_error is not None:
        return None, input_error

    # Model loading check (loads lazily if the eager preload failed)
    if not load_fastsam_model() or not fastsam_lib_imported or FastSAMPrompt is None:
        return image_pil, "Error: FastSAM model/library not ready. Check logs."

    print(f"Running FastSAM 'segment everything' with conf={conf_threshold}, iou={iou_threshold}...")

    try:
        image_np_rgb, everything_results = _fastsam_segment_all(
            image_pil, conf_threshold, iou_threshold, "FastSAM"
        )

        if not isinstance(everything_results, list) or not everything_results:
            print("FastSAM model returned None or empty results list.")
            return image_pil, "FastSAM processing returned no results."

        # NOTE(review): ultralytics-style results are assumed to expose
        # `masks = None` when nothing was detected -- confirm.  Bailing out
        # here avoids an AttributeError inside FastSAMPrompt.  Objects without
        # a `masks` attribute are passed through unchanged.
        if getattr(everything_results[0], "masks", True) is None:
            print("No masks detected or processed for 'segment everything' mode.")
            return image_pil, "No segments found or processed."

        prompt_process = FastSAMPrompt(image_np_rgb, everything_results, device=DEVICE)
        ann = prompt_process.everything_prompt()
        masks = _fastsam_masks_from_annotation(ann)

        if masks is None or len(masks) == 0:
            print(f"No masks found or annotation format unexpected. ann type: {type(ann)}")
            print("No masks detected or processed for 'segment everything' mode.")
            output_image = image_pil
            status_message = "No segments found or processed."
        else:
            print(f"Found {len(masks)} masks with shape: {masks.shape}")
            output_image, valid_masks_drawn, composited_ok = _fastsam_overlay_masks(image_pil, masks)
            if not composited_ok:
                status_message = f"Found {valid_masks_drawn} masks, but error during visualization."
            elif valid_masks_drawn == 0:
                status_message = f"Found {len(masks)} masks initially, but none were valid for drawing."
            else:
                status_message = f"Segmentation complete. Found and drew {valid_masks_drawn} masks."

        _fastsam_save_debug_image(output_image, "debug_fastsam_everything_output.png")
        return output_image, status_message

    except Exception as e:
        # UI boundary: show the error in the status box and keep the app alive.
        print(f"Error during FastSAM 'everything' processing: {e}")
        traceback.print_exc()
        return image_pil, f"Error during processing: {e}"


def run_text_prompted_segmentation(image_pil: Image.Image, text_prompts: str, conf_threshold: float = 0.4, iou_threshold: float = 0.9):
    """Segment the regions of an image that match comma-separated text prompts.

    FastSAM is run once over the whole image; each prompt is then resolved
    against those results through ``FastSAMPrompt.text_prompt`` and all
    matching masks are overlaid together.

    Args:
        image_pil: PIL image from the Gradio component (numpy arrays are
            converted).
        text_prompts: Comma-separated prompts; blank entries are ignored.
        conf_threshold: Confidence threshold forwarded to the model as ``conf``.
        iou_threshold: IoU threshold forwarded to the model as ``iou``.

    Returns:
        ``(image, status_message)``.  The status lists every prompt with the
        number of masks found for it.  ``image`` is the RGB overlay, the
        original input when nothing was drawn or processing failed, and
        ``None`` when the input itself was missing or invalid.
    """
    # Input validation
    if image_pil is None:
        return None, "Error: Please upload an image."
    image_pil, input_error = _fastsam_coerce_to_pil(image_pil, "FastSAM Text")
    if input_error is not None:
        return None, input_error

    # Model loading check (loads lazily if the eager preload failed)
    if not load_fastsam_model() or not fastsam_lib_imported or FastSAMPrompt is None:
        return image_pil, "Error: FastSAM model/library not ready. Check logs."
    if not text_prompts:
        return image_pil, "Please enter text prompts (e.g., 'person, dog')."

    prompts = [p.strip() for p in text_prompts.split(',') if p.strip()]
    if not prompts:
        return image_pil, "No valid text prompts entered."

    print(f"Running FastSAM text-prompted segmentation for: {prompts} with conf={conf_threshold}, iou={iou_threshold}")

    try:
        # Run FastSAM once to get all candidate segments.
        image_np_rgb, everything_results = _fastsam_segment_all(
            image_pil, conf_threshold, iou_threshold, "FastSAM Text"
        )

        if not isinstance(everything_results, list) or not everything_results:
            print("FastSAM model returned None or empty results for text prompt base.")
            return image_pil, "FastSAM did not return base results needed for text prompting."

        # NOTE(review): same assumption as in run_fastsam_segmentation --
        # `masks is None` is taken to mean "nothing detected"; confirm.
        if getattr(everything_results[0], "masks", True) is None:
            print("No matching masks found for any text prompt.")
            zero_counts = ', '.join(f"{text} (0)" for text in prompts)
            return image_pil, f"Results: {zero_counts}"

        prompt_process = FastSAMPrompt(image_np_rgb, everything_results, device=DEVICE)

        all_matching_masks = []
        found_prompts_details = []

        for text in prompts:
            print(f"  Processing prompt: '{text}'")
            ann = prompt_process.text_prompt(text=text)

            current_masks = _fastsam_masks_from_annotation(ann)
            num_found = 0 if current_masks is None else len(current_masks)
            if num_found:
                print(f"    Found {num_found} mask(s) for '{text}'. Shape: {current_masks.shape}")
                all_matching_masks.extend(current_masks)
            else:
                print(f"    No masks found or annotation format unexpected for '{text}'. ann type: {type(ann)}")

            found_prompts_details.append(f"{text} ({num_found})")

        # `prompts` is non-empty here, so there is always at least one entry.
        status_message = f"Results: {', '.join(found_prompts_details)}"

        if all_matching_masks:
            print(f"Total masks collected across all prompts: {len(all_matching_masks)}")
            output_image, valid_masks_drawn, composited_ok = _fastsam_overlay_masks(
                image_pil, all_matching_masks
            )
            if not composited_ok:
                status_message += " (Error during visualization)"
            elif valid_masks_drawn == 0:
                status_message += " (No valid masks to draw)"
            elif valid_masks_drawn < len(all_matching_masks):
                status_message += f" (Drew {valid_masks_drawn}/{len(all_matching_masks)} found masks)"
        else:
            print("No matching masks found for any text prompt.")
            output_image = image_pil

        _fastsam_save_debug_image(output_image, "debug_fastsam_text_output.png")
        return output_image, status_message

    except Exception as e:
        # UI boundary: show the error in the status box and keep the app alive.
        print(f"Error during FastSAM text-prompted processing: {e}")
        traceback.print_exc()
        return image_pil, f"Error during processing: {e}"

# --- Preload Models ---
# Runs once at import time, before the UI below is built, so the first request
# does not pay the model-loading cost. The loaders' boolean results are
# deliberately ignored: a failed preload is only logged, and each handler
# calls its loader again on demand before using the model.
print("Attempting to preload models...")
load_clip_model()
load_fastsam_model() # Eager attempt: imports fastsam, fetches weights if missing, builds the model
print("Preloading finished (check logs above for success/errors).")


# --- Gradio Interface Definition ---
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# CLIP & FastSAM Demo")
    gr.Markdown("Explore Zero-Shot Classification, 'Segment Everything', and Text-Prompted Segmentation.")
    gr.Markdown("---")
    gr.Markdown("**NOTE:** Ensure required libraries are installed: `pip install --upgrade gradio torch transformers Pillow numpy wget ultralytics` and `pip install git+https://github.com/CASIA-IVA-Lab/FastSAM.git`")
    gr.Markdown("---")


    with gr.Tabs():
        # --- CLIP Tab ---
        with gr.TabItem("CLIP Zero-Shot Classification"):
            gr.Markdown("Upload an image and provide comma-separated labels (e.g., 'cat, dog, car').")
            with gr.Row():
                with gr.Column(scale=1):
                    # Define UI elements first
                    clip_input_image = gr.Image(type="pil", label="Input Image")
                    clip_text_labels = gr.Textbox(label="Comma-Separated Labels", placeholder="e.g., astronaut, moon")
                    clip_button = gr.Button("Run CLIP Classification", variant="primary")
                with gr.Column(scale=1):
                    clip_output_label = gr.Label(label="Classification Probabilities")
                    clip_output_image_display = gr.Image(type="pil", label="Input Image Preview", interactive=False)

            # Define the click handler AFTER elements are defined
            clip_button.click(
                run_clip_zero_shot,
                inputs=[clip_input_image, clip_text_labels],
                outputs=[clip_output_label, clip_output_image_display]
            )

            gr.Examples(
                examples=[
                    ["examples/astronaut.jpg", "astronaut, moon, rover"],
                    ["examples/dog_bike.jpg", "dog, bicycle, person"],
                    ["examples/clip_logo.png", "logo, text, graphics"],
                ],
                inputs=[clip_input_image, clip_text_labels],
                outputs=[clip_output_label, clip_output_image_display],
                fn=run_clip_zero_shot,
                cache_examples=False, # Keep False during debugging
            )

        # --- FastSAM Everything Tab ---
        with gr.TabItem("FastSAM Segment Everything"):
            gr.Markdown("Upload an image to segment all objects/regions.")
            with gr.Row():
                # First column: input image, the two threshold sliders, run button.
                with gr.Column(scale=1):
                    fastsam_input_image_all = gr.Image(label="Input Image", type="pil")
                    with gr.Row():
                        fastsam_conf_all = gr.Slider(
                            label="Confidence Threshold",
                            minimum=0.1,
                            maximum=1.0,
                            step=0.05,
                            value=0.4,
                        )
                        fastsam_iou_all = gr.Slider(
                            label="IoU Threshold",
                            minimum=0.1,
                            maximum=1.0,
                            step=0.05,
                            value=0.9,
                        )
                    fastsam_button_all = gr.Button(
                        "Run FastSAM Segmentation",
                        variant="primary",
                    )
                # Second column: non-interactive result image and status text.
                with gr.Column(scale=1):
                    fastsam_output_image_all = gr.Image(
                        label="Segmented Image",
                        type="pil",
                        interactive=False,
                    )
                    fastsam_status_all = gr.Textbox(label="Status", interactive=False)

            # Wire the button once all components above exist: image, confidence
            # and IoU go in; the segmented image and a status string come out.
            fastsam_button_all.click(
                fn=run_fastsam_segmentation,
                inputs=[fastsam_input_image_all, fastsam_conf_all, fastsam_iou_all],
                outputs=[fastsam_output_image_all, fastsam_status_all],
            )

            # Sample rows: image path, confidence threshold, IoU threshold.
            gr.Examples(
                fn=run_fastsam_segmentation,
                inputs=[fastsam_input_image_all, fastsam_conf_all, fastsam_iou_all],
                outputs=[fastsam_output_image_all, fastsam_status_all],
                examples=[
                    ["examples/dogs.jpg", 0.4, 0.9],
                    ["examples/fruits.jpg", 0.5, 0.8],
                    ["examples/lion.jpg", 0.45, 0.9],
                ],
                cache_examples=False,
            )

        # --- Text-Prompted Segmentation Tab ---
        with gr.TabItem("Text-Prompted Segmentation"):
            gr.Markdown("Upload an image and provide comma-separated prompts (e.g., 'person, dog').")
            with gr.Row():
                # First column: input image, prompt text, threshold sliders, run button.
                with gr.Column(scale=1):
                    prompt_input_image = gr.Image(label="Input Image", type="pil")
                    prompt_text_input = gr.Textbox(
                        label="Comma-Separated Text Prompts",
                        placeholder="e.g., glasses, watch",
                    )
                    with gr.Row():
                        prompt_conf = gr.Slider(
                            label="Confidence Threshold",
                            minimum=0.1,
                            maximum=1.0,
                            step=0.05,
                            value=0.4,
                        )
                        prompt_iou = gr.Slider(
                            label="IoU Threshold",
                            minimum=0.1,
                            maximum=1.0,
                            step=0.05,
                            value=0.9,
                        )
                    prompt_button = gr.Button("Segment by Text", variant="primary")
                # Second column: non-interactive result image and status text.
                with gr.Column(scale=1):
                    prompt_output_image = gr.Image(
                        label="Text-Prompted Segmentation",
                        type="pil",
                        interactive=False,
                    )
                    prompt_status_message = gr.Textbox(label="Status", interactive=False)

            # Wire the button once all components above exist: image, prompt text,
            # confidence and IoU go in; the result image and a status string come out.
            prompt_button.click(
                fn=run_text_prompted_segmentation,
                inputs=[prompt_input_image, prompt_text_input, prompt_conf, prompt_iou],
                outputs=[prompt_output_image, prompt_status_message],
            )

            # Sample rows: image path, prompt text, confidence threshold, IoU threshold.
            gr.Examples(
                fn=run_text_prompted_segmentation,
                inputs=[prompt_input_image, prompt_text_input, prompt_conf, prompt_iou],
                outputs=[prompt_output_image, prompt_status_message],
                examples=[
                    ["examples/dog_bike.jpg", "person, bicycle", 0.4, 0.9],
                    ["examples/astronaut.jpg", "person, helmet", 0.35, 0.9],
                    ["examples/dogs.jpg", "dog", 0.4, 0.9],
                    ["examples/fruits.jpg", "banana, apple", 0.5, 0.8],
                    ["examples/teacher.jpg", "person, glasses", 0.4, 0.9],
                ],
                cache_examples=False,
            )

# --- Example File Download ---
# Module-level setup, deliberately kept outside the `with gr.Blocks...` block.
# NOTE(review): the gr.Examples(...) calls earlier in this file already refer
# to paths under "examples/" before this section runs -- confirm Gradio does
# not need the files to exist when Examples is constructed; otherwise this
# section should move ahead of the UI definition.

# Create the local "examples" directory on first run. A failure is only
# reported via print so that importing the module never aborts here.
if not os.path.exists("examples"):
    try:
        os.makedirs("examples")
        print("Created 'examples' directory.")
    except OSError as mkdir_error:
        print(f"Error creating 'examples' directory: {mkdir_error}")

# Local file name (inside "examples") -> URL the file is downloaded from.
example_files = {
    "astronaut.jpg": (
        "https://upload.wikimedia.org/wikipedia/commons/thumb/d/d1/"
        "Astronaut_-_St._Jean_Bay.jpg/640px-Astronaut_-_St._Jean_Bay.jpg"
    ),
    "dog_bike.jpg": (
        "https://huggingface.co/datasets/huggingface/documentation-images/"
        "resolve/main/gradio/outputs_multimodal.jpg"
    ),
    "clip_logo.png": "https://raw.githubusercontent.com/openai/CLIP/main/CLIP.png",
    "dogs.jpg": "https://raw.githubusercontent.com/ultralytics/assets/main/im/image8.jpg",
    "fruits.jpg": "https://raw.githubusercontent.com/ultralytics/assets/main/im/image9.jpg",
    "lion.jpg": (
        "https://huggingface.co/spaces/gradio/image-segmentation/"
        "resolve/main/images/lion.jpg"
    ),
    "teacher.jpg": (
        "https://images.pexels.com/photos/848117/pexels-photo-848117.jpeg"
        "?auto=compress&cs=tinysrgb&w=600"
    ),
}

def download_example_file(filename, url, retries=3):
    """Download one example file into ``examples`` unless it is already there.

    Args:
        filename: Name of the file inside the ``examples`` directory.
        url: Address handed to ``wget.download``.
        retries: Maximum number of download attempts.

    Returns:
        None. Progress and failures are reported with ``print``; any
        exception raised by a download attempt is caught, not propagated.
    """
    target = os.path.join("examples", filename)
    if os.path.exists(target):
        # Already on disk: nothing to fetch.
        return

    print(f"Attempting to download {filename}...")
    for attempt_no in range(1, retries + 1):
        try:
            wget.download(url, target)
            print(f"Downloaded {filename} successfully.")
            return  # Stop at the first successful attempt.
        except Exception as error:
            print(f"Download attempt {attempt_no}/{retries} for {filename} failed: {error}")
            # Best-effort removal of anything left at the target path so a
            # later existence check does not mistake it for a good file.
            if os.path.exists(target):
                try:
                    os.remove(target)
                except OSError:
                    pass
            if attempt_no == retries:
                print(f"Failed to download {filename} after {retries} attempts.")

# Fetch every listed example file, but only when the "examples" directory is
# actually present; otherwise just report that the step was skipped.
if not os.path.exists("examples"):
    print("Skipping example download because 'examples' directory could not be created.")
else:
    for filename, url in example_files.items():
        download_example_file(filename, url)
    print("Example file check/download process complete.")


# --- Launch App ---
if __name__ == "__main__":
    # Startup banner with troubleshooting hints, printed before the app launches.
    print(
        "-----------------------------------------",
        "Launching Gradio Demo...",
        "Ensure FastSAM model and weights are correctly loaded (check logs above).",
        "If FastSAM fails, check installation: pip install ultralytics && pip install git+https://github.com/CASIA-IVA-Lab/FastSAM.git",
        "-----------------------------------------",
        sep="\n",
    )
    # Keep debug=True for detailed Gradio errors.
    demo.launch(debug=True)