import gradio as gr
import torch
from transformers import AutoProcessor, AutoModel
from PIL import Image, ImageDraw, ImageFont
import numpy as np
import random
import os
import wget # To download weights
import traceback # For detailed error printing

# --- Configuration & Model Loading ---

# Device Selection
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {DEVICE}")

# --- CLIP Setup ---
CLIP_MODEL_ID = "openai/clip-vit-base-patch32"
clip_processor = None
clip_model = None

def load_clip_model():
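    # Lazily load the CLIP processor and model the first time they are needed,
    # keeping them in module-level globals so they are only loaded once.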
    global clip_processor, clip_model
    if clip_processor is None:
        print(f"Loading CLIP processor: {CLIP_MODEL_ID}...")
        clip_processor = AutoProcessor.from_pretrained(CLIP_MODEL_ID)
        print("CLIP processor loaded.")
    if clip_model is None:
        print(f"Loading CLIP model: {CLIP_MODEL_ID}...")
        clip_model = AutoModel.from_pretrained(CLIP_MODEL_ID).to(DEVICE)
        print(f"CLIP model loaded to {DEVICE}.")

# --- FastSAM Setup ---
FASTSAM_CHECKPOINT = "FastSAM-s.pt"
# Hugging Face Hub URL assumed to mirror the official FastSAM-s checkpoint (verify it is reachable)
FASTSAM_CHECKPOINT_URL = f"https://huggingface.co/CASIA-IVA-Lab/FastSAM-s/resolve/main/{FASTSAM_CHECKPOINT}"

fastsam_model = None
fastsam_lib_imported = False # Flag to check if import worked

def check_and_import_fastsam():
    global fastsam_lib_imported
    if not fastsam_lib_imported:
        try:
            from fastsam import FastSAM, FastSAMPrompt
            globals()['FastSAM'] = FastSAM # Make classes available globally
            globals()['FastSAMPrompt'] = FastSAMPrompt
            fastsam_lib_imported = True
            print("fastsam library imported successfully.")
        except ImportError:
            print("Error: 'fastsam' library not found or import failed.")
            print("Please ensure 'fastsam' is installed correctly (pip install fastsam).")
            fastsam_lib_imported = False
        except Exception as e:
            print(f"An unexpected error occurred during fastsam import: {e}")
            fastsam_lib_imported = False
    return fastsam_lib_imported


def download_fastsam_weights():
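    # Download the checkpoint once; returns True only if the file exists on disk afterwards.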
    if not os.path.exists(FASTSAM_CHECKPOINT):
        print(f"Downloading FastSAM weights: {FASTSAM_CHECKPOINT} from {FASTSAM_CHECKPOINT_URL}...")
        try:
            wget.download(FASTSAM_CHECKPOINT_URL, FASTSAM_CHECKPOINT)
            print("FastSAM weights downloaded.")
        except Exception as e:
            print(f"Error downloading FastSAM weights: {e}")
            print("Please ensure the URL is correct and reachable, or manually place the weights file.")
            # Attempt to remove a partially downloaded file if it exists
            if os.path.exists(FASTSAM_CHECKPOINT):
                 try:
                     os.remove(FASTSAM_CHECKPOINT)
                 except OSError:
                     pass # Ignore removal errors
            return False
    return os.path.exists(FASTSAM_CHECKPOINT)

def load_fastsam_model():
    global fastsam_model
    if fastsam_model is None:
        if not check_and_import_fastsam(): # Check import first
             print("Cannot load FastSAM model because the library couldn't be imported.")
             return # Exit if import failed

        if download_fastsam_weights(): # Check download/existence second
            try:
                # FastSAM class should be available via globals() now
                print(f"Loading FastSAM model: {FASTSAM_CHECKPOINT}...")
                fastsam_model = FastSAM(FASTSAM_CHECKPOINT)
                print(f"FastSAM model loaded.") # Device handled internally by FastSAM
            except Exception as e:
                print(f"Error loading FastSAM model: {e}")
                traceback.print_exc()
        else:
            print("FastSAM weights not found or download failed. Cannot load model.")


# --- Processing Functions ---

# CLIP Zero-Shot Classification Function
def run_clip_zero_shot(image: Image.Image, text_labels: str):
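    # Returns a (label -> probability) mapping (or an error string / empty dict)
    # plus the image to display in the preview pane.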
    if clip_model is None or clip_processor is None:
        load_clip_model() # Attempt to load if not already loaded
        if clip_model is None:
             return "Error: CLIP Model not loaded. Check logs.", None

    if image is None:
        return "Please upload an image.", None # Return None for the image display
    if not text_labels:
        # Return empty results but display the uploaded image
        return {}, image

    labels = [label.strip() for label in text_labels.split(',') if label.strip()] # Ensure non-empty labels
    if not labels:
         # Return empty results but display the uploaded image
         return {}, image

    print(f"Running CLIP zero-shot classification with labels: {labels}")

    try:
        # Ensure image is RGB
        if image.mode != "RGB":
            image = image.convert("RGB")

        inputs = clip_processor(text=labels, images=image, return_tensors="pt", padding=True).to(DEVICE)

        with torch.no_grad():
            outputs = clip_model(**inputs)
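            # logits_per_image has shape (1, num_labels): one similarity score per
            # candidate label; softmax over dim=1 turns these scores into probabilities.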
            logits_per_image = outputs.logits_per_image
            probs = logits_per_image.softmax(dim=1)

        print("CLIP processing complete.")

        confidences = {labels[i]: float(probs[0, i].item()) for i in range(len(labels))}
        # Return results and the original image used for prediction
        return confidences, image

    except Exception as e:
        print(f"Error during CLIP processing: {e}")
        traceback.print_exc()
        # Return error message and the original image
        return f"An error occurred during CLIP: {e}", image


# FastSAM Segmentation Function
def run_fastsam_segmentation(image_pil: Image.Image, conf_threshold: float = 0.4, iou_threshold: float = 0.9):
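    # Returns a PIL image with all predicted masks alpha-blended over the input.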
    # Ensure the model is loaded, attempting a lazy load if necessary
    if fastsam_model is None:
        load_fastsam_model()
        if fastsam_model is None:
            # Raise gr.Error so the message is shown in the UI instead of handing an
            # error string to the Image output component.
            raise gr.Error("FastSAM model not loaded. Check logs.")
    # Ensure the library was imported
    if not fastsam_lib_imported:
        raise gr.Error("FastSAM library not available. Cannot run segmentation.")

    if image_pil is None:
        raise gr.Error("Please upload an image.")

    print("Running FastSAM segmentation...")

    try:
         # Ensure image is RGB
        if image_pil.mode != "RGB":
            image_pil = image_pil.convert("RGB")

        image_np_rgb = np.array(image_pil)

        # Run FastSAM inference
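        # retina_masks=True requests full-resolution masks, imgsz sets the inference
        # resolution, and conf / iou are the detection confidence and NMS IoU
        # thresholds exposed by the sliders in the UI.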
        everything_results = fastsam_model(
            image_np_rgb,
            device=DEVICE,
            retina_masks=True,
            imgsz=640,
            conf=conf_threshold,
            iou=iou_threshold,
        )

        # FastSAMPrompt should be available via globals() if import succeeded
        prompt_process = FastSAMPrompt(image_np_rgb, everything_results, device=DEVICE)
        ann = prompt_process.everything_prompt()
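        # NOTE: the structure returned by everything_prompt() varies between fastsam
        # releases (a list of dicts with a 'masks' tensor vs. a bare mask tensor);
        # the handling below assumes the dict-style output.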

        print(f"FastSAM found {len(ann[0]['masks']) if ann and ann[0] and 'masks' in ann[0] else 0} masks.")

        # --- Plotting Masks on Image ---
        output_image = image_pil.copy()
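        # Each mask is drawn onto a transparent RGBA overlay in a random
        # semi-transparent colour, then the overlay is alpha-composited over the original image.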
        if ann and ann[0] is not None and 'masks' in ann[0] and len(ann[0]['masks']) > 0:
            masks = ann[0]['masks'].cpu().numpy() # (N, H, W) boolean

            overlay = Image.new('RGBA', output_image.size, (0, 0, 0, 0))
            draw = ImageDraw.Draw(overlay)

            for i in range(masks.shape[0]):
                mask = masks[i]
                color = (random.randint(50, 255), random.randint(50, 255), random.randint(50, 255), 128) # RGBA
                mask_image = Image.fromarray((mask * 255).astype(np.uint8), mode='L')
                draw.bitmap((0,0), mask_image, fill=color)

            output_image = Image.alpha_composite(output_image.convert('RGBA'), overlay).convert('RGB')

        print("FastSAM processing and plotting complete.")
        # Return only the output image for the single Image output component
        return output_image

    except NameError as ne:
        print(f"NameError during FastSAM processing: {ne}. Was the fastsam library imported correctly?")
        traceback.print_exc()
        raise gr.Error(f"A NameError occurred: {ne}. Check the fastsam library import.")
    except Exception as e:
        print(f"Error during FastSAM processing: {e}")
        traceback.print_exc()
        raise gr.Error(f"An error occurred during FastSAM: {e}")


# --- Gradio Interface ---

# Download example images (optional, but helpful) before the UI is defined so the
# gr.Examples components below can find the files.
if not os.path.exists("examples"):
    os.makedirs("examples")
    print("Created 'examples' directory. Attempting to download sample images...")
    example_files = {
        # NOTE: placeholder URL (currently malformed); replace with a suitable public domain/CC image.
        "astronaut.jpg": "https://upload.wikimedia.org/wikipedia/commons/thumb/d/d astronaut_-_St._Jean_Bay.jpg/640px-Astronaut_-_St._Jean_Bay.jpg",
        # A relevant example image from the Hugging Face documentation assets
        "dog_bike.jpg": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/gradio/outputs_multimodal.jpg",
        "clip_logo.png": "https://raw.githubusercontent.com/openai/CLIP/main/CLIP.png",
        "dogs.jpg": "https://raw.githubusercontent.com/ultralytics/assets/main/im/image8.jpg",    # From Ultralytics assets
        "fruits.jpg": "https://raw.githubusercontent.com/ultralytics/assets/main/im/image9.jpg",  # From Ultralytics assets
        "lion.jpg": "https://huggingface.co/spaces/gradio/image-segmentation/resolve/main/images/lion.jpg",
    }
    for filename, url in example_files.items():
        filepath = os.path.join("examples", filename)
        if not os.path.exists(filepath):
            try:
                print(f"Downloading {filename}...")
                wget.download(url, filepath)
            except Exception as e:
                print(f"Could not download {filename} from {url}: {e}")
    print("Example image download attempt finished.")

# Pre-load models on startup (optional, but avoids a slow first request)
print("Attempting to preload models...")
load_clip_model()
load_fastsam_model()  # Also checks the fastsam import and downloads weights if needed
print("Preloading finished (or attempted).")


with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# CLIP & FastSAM Demo")
    gr.Markdown("Explore Zero-Shot Classification with CLIP and 'Segment Anything' with FastSAM.")

    with gr.Tabs():
        # --- CLIP Tab ---
        with gr.TabItem("CLIP Zero-Shot Classification"):
            gr.Markdown("Upload an image and provide comma-separated candidate labels (e.g., 'cat, dog, car'). CLIP will predict the probability of the image matching each label.")
            with gr.Row():
                with gr.Column(scale=1):
                    clip_input_image = gr.Image(type="pil", label="Input Image")
                    clip_text_labels = gr.Textbox(label="Comma-Separated Labels", placeholder="e.g., astronaut, moon, dog playing fetch")
                    clip_button = gr.Button("Run CLIP Classification", variant="primary")
                with gr.Column(scale=1):
                    clip_output_label = gr.Label(label="Classification Probabilities")
                    clip_output_image_display = gr.Image(type="pil", label="Input Image Preview")

            clip_button.click(
                run_clip_zero_shot,
                inputs=[clip_input_image, clip_text_labels],
                outputs=[clip_output_label, clip_output_image_display]
            )
            gr.Examples(
                examples=[
                    ["examples/astronaut.jpg", "astronaut, moon, rover, mountain"],
                    ["examples/dog_bike.jpg", "dog, bicycle, person, park, grass"],
                    ["examples/clip_logo.png", "logo, text, graphics, abstract art"], # Added another example
                ],
                inputs=[clip_input_image, clip_text_labels],
                outputs=[clip_output_label, clip_output_image_display],
                fn=run_clip_zero_shot,
                cache_examples=False,
            )

        # --- FastSAM Tab ---
        with gr.TabItem("FastSAM Segmentation"):
            gr.Markdown("Upload an image. FastSAM will attempt to segment all objects/regions in the image.")
            with gr.Row():
                with gr.Column(scale=1):
                    fastsam_input_image = gr.Image(type="pil", label="Input Image")
                    with gr.Row():
                        fastsam_conf = gr.Slider(minimum=0.1, maximum=1.0, value=0.4, step=0.05, label="Confidence Threshold")
                        fastsam_iou = gr.Slider(minimum=0.1, maximum=1.0, value=0.9, step=0.05, label="IoU Threshold")
                    fastsam_button = gr.Button("Run FastSAM Segmentation", variant="primary")
                with gr.Column(scale=1):
                    fastsam_output_image = gr.Image(type="pil", label="Segmented Image")

            fastsam_button.click(
                run_fastsam_segmentation,
                inputs=[fastsam_input_image, fastsam_conf, fastsam_iou],
                # The single returned image maps to this single output component
                outputs=[fastsam_output_image]
            )
            gr.Examples(
                examples=[
                    ["examples/dogs.jpg", 0.4, 0.9],
                    ["examples/fruits.jpg", 0.5, 0.8],
                    ["examples/lion.jpg", 0.45, 0.9], # Added another example
                ],
                inputs=[fastsam_input_image, fastsam_conf, fastsam_iou],
                outputs=[fastsam_output_image],
                fn=run_fastsam_segmentation,
                cache_examples=False,
            )


# Launch the Gradio app
if __name__ == "__main__":
    # share=True is primarily for local testing to get a public link.
    # Not needed/used when deploying on Hugging Face Spaces.
    # debug=True is helpful for development. Set to False for production.
    demo.launch(debug=True)