import gradio as gr
import torch
from PIL import Image
import numpy as np
from transformers import CLIPProcessor, CLIPModel
from ultralytics import FastSAM
import supervision as sv
import os
import requests
from tqdm.auto import tqdm  # For a nice progress bar

# --- Constants and Model Initialization ---

# CLIP
CLIP_MODEL_NAME = "openai/clip-vit-base-patch32"

# FastSAM
FASTSAM_WEIGHTS_URL = "https://huggingface.co/spaces/An-619/FastSAM/resolve/main/weights/FastSAM-s.pt"
FASTSAM_WEIGHTS_NAME = "FastSAM-s.pt"

# Default FastSAM parameters
DEFAULT_IMGSZ = 640
DEFAULT_CONFIDENCE = 0.4
DEFAULT_IOU = 0.9
DEFAULT_RETINA_MASKS = False
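# These defaults mirror the standard Ultralytics predict() arguments: imgsz sets the
# inference resolution, conf filters out low-confidence detections, iou is the NMS
# overlap threshold, and retina_masks trades speed for higher-resolution masks.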

# --- Helper Functions ---

def download_file(url, filename):
    """Downloads a file from a URL with a progress bar."""
    response = requests.get(url, stream=True)
    response.raise_for_status()  # Raise an exception for bad status codes

    total_size = int(response.headers.get('content-length', 0))
    block_size = 1024  # 1 KB
    progress_bar = tqdm(total=total_size, unit='iB', unit_scale=True)

    with open(filename, 'wb') as file:
        for data in response.iter_content(block_size):
            progress_bar.update(len(data))
            file.write(data)
    progress_bar.close()

    if total_size != 0 and progress_bar.n != total_size:
        raise ValueError(f"Download of {filename} failed: received {progress_bar.n} of {total_size} bytes.")

# --- Model Loading ---

# Load CLIP model and processor
model = CLIPModel.from_pretrained(CLIP_MODEL_NAME)
processor = CLIPProcessor.from_pretrained(CLIP_MODEL_NAME)

# Load FastSAM model with dynamic device handling
if not os.path.exists(FASTSAM_WEIGHTS_NAME):
    print(f"Downloading FastSAM weights from {FASTSAM_WEIGHTS_URL}...")
    try:
        download_file(FASTSAM_WEIGHTS_URL, FASTSAM_WEIGHTS_NAME)
        print("FastSAM weights downloaded successfully.")
    except Exception as e:
        print(f"Error downloading FastSAM weights: {e}")
        raise  # Re-raise the exception to stop execution

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
fast_sam = FastSAM(FASTSAM_WEIGHTS_NAME)
fast_sam.to(device)
print(f"FastSAM loaded on device: {device}")

# --- Processing Functions ---

def process_image_clip(image, text_input):
    """Use CLIP to estimate how likely it is that the image contains the given text concept."""
    if image is None:
        return "Please upload an image first."
    if not text_input:
        return "Please enter some text to check in the image."

    try:
        # Convert numpy array to PIL Image if needed
        if isinstance(image, np.ndarray):
            image = Image.fromarray(image)

        # Create a list of candidate labels
        candidate_labels = [text_input, f"not {text_input}"]

        # Process image and text
        inputs = processor(
            images=image,
            text=candidate_labels,
            return_tensors="pt",
            padding=True
        )

        # Get model predictions (inference only, no gradients needed)
        with torch.no_grad():
            outputs = model(**inputs)
        logits_per_image = outputs.logits_per_image
        probs = logits_per_image.softmax(dim=1)

        # Get confidence for the positive label
        confidence = float(probs[0][0])
        return f"Confidence that the image contains '{text_input}': {confidence:.2%}"
    except Exception as e:
        return f"Error processing image: {str(e)}"

def process_image_fastsam(image, imgsz, conf, iou, retina_masks):
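    """Run FastSAM segmentation and return (annotated PIL image, error message or None)."""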
    if image is None:
        return None, "Please upload an image to segment."

    try:
        # Convert PIL image to numpy array if needed
        if isinstance(image, Image.Image):
            image_np = np.array(image)
        else:
            image_np = image

        # Run FastSAM inference
        results = fast_sam(image_np, device=device, retina_masks=retina_masks, imgsz=imgsz, conf=conf, iou=iou)

        # Check if results are valid
        if results is None or len(results) == 0 or results[0] is None:
            return None, "FastSAM did not return valid results. Try adjusting parameters or using a different image."

        # Get detections and check that something was found
        detections = sv.Detections.from_ultralytics(results[0])
        if detections is None or len(detections) == 0:
            return None, "No objects detected in the image. Try lowering the confidence threshold."

        # Create annotator
        box_annotator = sv.BoxAnnotator()
        mask_annotator = sv.MaskAnnotator()

        # Annotate image
        annotated_image = mask_annotator.annotate(scene=image_np.copy(), detections=detections)
        annotated_image = box_annotator.annotate(scene=annotated_image, detections=detections)

        return Image.fromarray(annotated_image), None # Return None for the error message since there's no error

    except RuntimeError as runtime_error:
        if "out of memory" in str(runtime_error).lower():
            return None, "Error: Out of memory. Try reducing the image size (imgsz) or disabling retina masks."
        return None, f"Runtime error during FastSAM processing: {str(runtime_error)}"

    except Exception as e:
        return None, f"Error processing image with FastSAM: {str(e)}"

# --- Gradio Interface ---

with gr.Blocks(css="footer {visibility: hidden}") as demo:
    gr.Markdown("""
    # CLIP and FastSAM Demo
    This demo combines two powerful AI models:
    - **CLIP**: For zero-shot image classification
    - **FastSAM**: For automatic image segmentation

    Upload an image and try either of the tabs below!
    """)

    with gr.Tab("CLIP Zero-Shot Classification"):
        with gr.Row():
            image_input = gr.Image(label="Input Image")
            text_input = gr.Textbox(
                label="What do you want to check in the image?",
                placeholder="e.g., 'a dog', 'sunset', 'people playing'",
                info="Enter any concept you want to check in the image"
            )
        output_text = gr.Textbox(label="Result")
        classify_btn = gr.Button("Classify")
        classify_btn.click(fn=process_image_clip, inputs=[image_input, text_input], outputs=output_text)

        gr.Examples(
            examples=[
                ["https://raw.githubusercontent.com/gradio-app/gradio/main/demo/kitchen/kitchen.png", "kitchen"],
                ["https://raw.githubusercontent.com/gradio-app/gradio/main/demo/calculator/calculator.jpg", "calculator"],
            ],
            inputs=[image_input, text_input],
        )

    with gr.Tab("FastSAM Segmentation"):
        with gr.Row():
            image_input_sam = gr.Image(label="Input Image")
            with gr.Column():
                imgsz_slider = gr.Slider(minimum=320, maximum=1920, step=32, value=DEFAULT_IMGSZ, label="Image Size (imgsz)")
                conf_slider = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, value=DEFAULT_CONFIDENCE, label="Confidence Threshold")
                iou_slider = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, value=DEFAULT_IOU, label="IoU Threshold")
                retina_checkbox = gr.Checkbox(label="Retina Masks", value=DEFAULT_RETINA_MASKS)
            
        with gr.Row():
            image_output = gr.Image(label="Segmentation Result")
            error_output = gr.Textbox(label="Error Message", type="text")  # Displays errors from processing

        segment_btn = gr.Button("Segment")
        segment_btn.click(
            fn=process_image_fastsam,
            inputs=[image_input_sam, imgsz_slider, conf_slider, iou_slider, retina_checkbox],
            outputs=[image_output, error_output] # Output to both image and error textboxes
        )

        gr.Examples(
            examples=[
                ["https://raw.githubusercontent.com/gradio-app/gradio/main/demo/kitchen/kitchen.png"],
                ["https://raw.githubusercontent.com/gradio-app/gradio/main/demo/calculator/calculator.jpg"],
            ],
            inputs=[image_input_sam],
        )

    gr.Markdown("""
    ### How to use:
    1. **CLIP Classification**: Upload an image and enter text to check if that concept exists in the image
    2. **FastSAM Segmentation**: Upload an image to get automatic segmentation with bounding boxes and masks

    ### Note:
    - The models run on the CPU when no GPU is available, so processing may take a few seconds; a CUDA GPU is used automatically when present.
    - For best results, use clear images with good lighting.
    - You can adjust FastSAM parameters (Image Size, Confidence, IoU, Retina Masks) in the Segmentation tab.
    """)

demo.launch(share=True)