sagar007 committed · verified
Commit b066832 · Parent(s): e31b682

Update app.py

Files changed (1):
  1. app.py (+237 -180)
app.py CHANGED
@@ -1,220 +1,277 @@
  import gradio as gr
  import torch
- from PIL import Image
- import cv2
  import numpy as np
- from transformers import CLIPProcessor, CLIPModel
- from ultralytics import FastSAM
- import supervision as sv
  import os
- import requests
- from tqdm.auto import tqdm # For a nice progress bar

- # --- Constants and Model Initialization ---

- # CLIP
- CLIP_MODEL_NAME = "openai/clip-vit-base-patch32"

- # FastSAM
- # *Corrected* HuggingFace link for the weights
- FASTSAM_WEIGHTS_URL = "https://huggingface.co/spaces/An-619/FastSAM/resolve/6f76f474c656d2cb29599f49c296a8784b02d04b/weights/FastSAM-s.pt"
- FASTSAM_WEIGHTS_NAME = "FastSAM-s.pt"

- # Default FastSAM parameters
- DEFAULT_IMGSZ = 640
- DEFAULT_CONFIDENCE = 0.4
- DEFAULT_IOU = 0.9
- DEFAULT_RETINA_MASKS = False

- # --- Helper Functions ---

- def download_file(url, filename):
-     """Downloads a file from a URL with a progress bar."""
-     response = requests.get(url, stream=True)
-     response.raise_for_status() # Raise an exception for bad status codes

-     total_size = int(response.headers.get('content-length', 0))
-     block_size = 1024 # 1 KB
-     progress_bar = tqdm(total=total_size, unit='iB', unit_scale=True)

-     with open(filename, 'wb') as file:
-         for data in response.iter_content(block_size):
-             progress_bar.update(len(data))
-             file.write(data)
-     progress_bar.close()

-     if total_size != 0 and progress_bar.n != total_size:
-         raise ValueError("Error: Download failed.")

- # --- Model Loading ---

- # Load CLIP model (this part is correct in your original code)
- model = CLIPModel.from_pretrained(CLIP_MODEL_NAME)
- processor = CLIPProcessor.from_pretrained(CLIP_MODEL_NAME)

- # Load FastSAM model with dynamic device handling
- if not os.path.exists(FASTSAM_WEIGHTS_NAME):
-     print(f"Downloading FastSAM weights from {FASTSAM_WEIGHTS_URL}...")
-     try:
-         download_file(FASTSAM_WEIGHTS_URL, FASTSAM_WEIGHTS_NAME)
-         print("FastSAM weights downloaded successfully.")
      except Exception as e:
-         print(f"Error downloading FastSAM weights: {e}")
-         raise # Re-raise the exception to stop execution

- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
- fast_sam = FastSAM(FASTSAM_WEIGHTS_NAME)
- fast_sam.to(device)
- print(f"FastSAM loaded on device: {device}")

- # --- Processing Functions ---

- def process_image_clip(image, text_input):
-     # ... (Your CLIP processing function remains the same) ...
-     if image is None:
-         return "Please upload an image first."
-     if not text_input:
-         return "Please enter some text to check in the image."

      try:
-         # Convert numpy array to PIL Image if needed
-         if isinstance(image, np.ndarray):
-             image = Image.fromarray(image)
-
-         # Create a list of candidate labels
-         candidate_labels = [text_input, f"not {text_input}"]
-
-         # Process image and text
-         inputs = processor(
-             images=image,
-             text=candidate_labels,
-             return_tensors="pt",
-             padding=True
-         )
-
-         # Get model predictions
-         outputs = model(**{k: v for k, v in inputs.items()})
-         logits_per_image = outputs.logits_per_image
-         probs = logits_per_image.softmax(dim=1)

-         # Get confidence for the positive label
-         confidence = float(probs[0][0])
-         return f"Confidence that the image contains '{text_input}': {confidence:.2%}"
-     except Exception as e:
-         return f"Error processing image: {str(e)}"

- def process_image_fastsam(image, imgsz, conf, iou, retina_masks):
-     if image is None:
-         return None, "Please upload an image to segment."

-     try:
-         # Convert PIL image to numpy array if needed
-         if isinstance(image, Image.Image):
-             image_np = np.array(image)
-         else:
-             image_np = image

-         # Run FastSAM inference
-         results = fast_sam(image_np, device=device, retina_masks=retina_masks, imgsz=imgsz, conf=conf, iou=iou)

-         # Check if results are valid
-         if results is None or len(results) == 0 or results[0] is None:
-             return None, "FastSAM did not return valid results. Try adjusting parameters or using a different image."

-         # Get detections
-         detections = sv.Detections.from_ultralytics(results[0])
-         # Check if detections are valid
-         if detections is None or len(detections) == 0:
-             return None, "No objects detected in the image. Try lowering the confidence threshold."

-         # Create annotator
-         box_annotator = sv.BoxAnnotator()
-         mask_annotator = sv.MaskAnnotator()

-         # Annotate image
-         annotated_image = mask_annotator.annotate(scene=image_np.copy(), detections=detections)
-         annotated_image = box_annotator.annotate(scene=annotated_image, detections=detections)

-         return Image.fromarray(annotated_image), None # Return None for the error message since there's no error

-     except RuntimeError as re:
-         if "out of memory" in str(re).lower():
-             return None, "Error: Out of memory. Try reducing the image size (imgsz) or disabling retina masks."
-         else:
-             return None, f"Runtime error during FastSAM processing: {str(re)}"

      except Exception as e:
-         return None, f"Error processing image with FastSAM: {str(e)}"

  # --- Gradio Interface ---

- with gr.Blocks(css="footer {visibility: hidden}") as demo:
-     # ... (Your Markdown and CLIP tab remain mostly the same) ...
-     gr.Markdown("""
-     # CLIP and FastSAM Demo
-     This demo combines two powerful AI models:
-     - **CLIP**: For zero-shot image classification
-     - **FastSAM**: For automatic image segmentation
-     Try uploading an image and use either of the tabs below!
-     """)
-
-     with gr.Tab("CLIP Zero-Shot Classification"):
-         with gr.Row():
-             image_input = gr.Image(label="Input Image")
-             text_input = gr.Textbox(
-                 label="What do you want to check in the image?",
-                 placeholder="e.g., 'a dog', 'sunset', 'people playing'",
-                 info="Enter any concept you want to check in the image"
              )
-         output_text = gr.Textbox(label="Result")
-         classify_btn = gr.Button("Classify")
-         classify_btn.click(fn=process_image_clip, inputs=[image_input, text_input], outputs=output_text)
-
-         gr.Examples(
-             examples=[
-                 ["https://raw.githubusercontent.com/gradio-app/gradio/main/demo/kitchen/kitchen.png", "kitchen"],
-                 ["https://raw.githubusercontent.com/gradio-app/gradio/main/demo/calculator/calculator.jpg", "calculator"],
-             ],
-             inputs=[image_input, text_input],
-         )
-
-     with gr.Tab("FastSAM Segmentation"):
-         with gr.Row():
-             image_input_sam = gr.Image(label="Input Image")
-             with gr.Column():
-                 imgsz_slider = gr.Slider(minimum=320, maximum=1920, step=32, value=DEFAULT_IMGSZ, label="Image Size (imgsz)")
-                 conf_slider = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, value=DEFAULT_CONFIDENCE, label="Confidence Threshold")
-                 iou_slider = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, value=DEFAULT_IOU, label="IoU Threshold")
-                 retina_checkbox = gr.Checkbox(label="Retina Masks", value=DEFAULT_RETINA_MASKS)
-
-         with gr.Row():
-             image_output = gr.Image(label="Segmentation Result")
-             error_output = gr.Textbox(label="Error Message", type="text") # Added for displaying errors
-
-         segment_btn = gr.Button("Segment")
-         segment_btn.click(
-             fn=process_image_fastsam,
-             inputs=[image_input_sam, imgsz_slider, conf_slider, iou_slider, retina_checkbox],
-             outputs=[image_output, error_output] # Output to both image and error textboxes
-         )

-         gr.Examples(
-             examples=[
-                 ["https://raw.githubusercontent.com/gradio-app/gradio/main/demo/kitchen/kitchen.png"],
-                 ["https://raw.githubusercontent.com/gradio-app/gradio/main/demo/calculator/calculator.jpg"],
-             ],
-             inputs=[image_input_sam],
-         )

-     # ... (Your final Markdown remains the same) ...
-     gr.Markdown("""
-     ### How to use:
-     1. **CLIP Classification**: Upload an image and enter text to check if that concept exists in the image
-     2. **FastSAM Segmentation**: Upload an image to get automatic segmentation with bounding boxes and masks
-     ### Note:
-     - The models run on CPU by default, so processing might take a few seconds. If you have a GPU, it will be used automatically.
-     - For best results, use clear images with good lighting.
-     - You can adjust FastSAM parameters (Image Size, Confidence, IoU, Retina Masks) in the Segmentation tab.
-     """)
-
- demo.launch(share=True)
  import gradio as gr
  import torch
+ from transformers import AutoProcessor, AutoModel
+ from PIL import Image, ImageDraw, ImageFont
  import numpy as np
+ import random
  import os
+ import wget # To download weights
+
+ # --- Configuration & Model Loading ---
+
+ # Device Selection
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+ print(f"Using device: {DEVICE}")
+
+ # --- CLIP Setup ---
+ CLIP_MODEL_ID = "openai/clip-vit-base-patch32"
+ clip_processor = None
+ clip_model = None
+
+ def load_clip_model():
+     global clip_processor, clip_model
+     if clip_processor is None:
+         print(f"Loading CLIP processor: {CLIP_MODEL_ID}...")
+         clip_processor = AutoProcessor.from_pretrained(CLIP_MODEL_ID)
+         print("CLIP processor loaded.")
+     if clip_model is None:
+         print(f"Loading CLIP model: {CLIP_MODEL_ID}...")
+         clip_model = AutoModel.from_pretrained(CLIP_MODEL_ID).to(DEVICE)
+         print(f"CLIP model loaded to {DEVICE}.")
+
+ # --- FastSAM Setup ---
+ # Use a smaller model suitable for Spaces CPU/basic GPU if needed
+ FASTSAM_CHECKPOINT = "FastSAM-s.pt"
+ FASTSAM_CHECKPOINT_URL = f"https://huggingface.co/spaces/An-619/FastSAM/resolve/main/{FASTSAM_CHECKPOINT}" # Example URL, find official if possible
+
+ fastsam_model = None
+
+ def download_fastsam_weights():
+     if not os.path.exists(FASTSAM_CHECKPOINT):
+         print(f"Downloading FastSAM weights: {FASTSAM_CHECKPOINT}...")
+         try:
+             wget.download(FASTSAM_CHECKPOINT_URL, FASTSAM_CHECKPOINT)
+             print("FastSAM weights downloaded.")
+         except Exception as e:
+             print(f"Error downloading FastSAM weights: {e}")
+             print("Please ensure the URL is correct and reachable, or manually place the weights file.")
+             return False
+     return os.path.exists(FASTSAM_CHECKPOINT)
+
+ def load_fastsam_model():
+     global fastsam_model
+     if fastsam_model is None:
+         if download_fastsam_weights():
+             try:
+                 from fastsam import FastSAM, FastSAMPrompt # Import here after potential download
+                 print(f"Loading FastSAM model: {FASTSAM_CHECKPOINT}...")
+                 fastsam_model = FastSAM(FASTSAM_CHECKPOINT)
+                 print(f"FastSAM model loaded.") # Device handled internally by FastSAM based on its setup/torch device
+             except ImportError:
+                 print("Error: 'fastsam' library not found. Please install it (pip install fastsam).")
+             except Exception as e:
+                 print(f"Error loading FastSAM model: {e}")
+         else:
+             print("FastSAM weights not found. Cannot load model.")


+ # --- Processing Functions ---

+ # CLIP Zero-Shot Classification Function
+ def run_clip_zero_shot(image: Image.Image, text_labels: str):
+     if clip_model is None or clip_processor is None:
+         load_clip_model() # Attempt to load if not already loaded
+         if clip_model is None:
+             return "Error: CLIP Model not loaded. Check logs.", None

+     if not text_labels:
+         return "Please provide comma-separated text labels.", None
+     if image is None:
+         return "Please upload an image.", None

+     labels = [label.strip() for label in text_labels.split(',')]
+     if not labels:
+         return "No valid labels provided.", None

+     print(f"Running CLIP zero-shot classification with labels: {labels}")

+     try:
+         # Ensure image is RGB
+         if image.mode != "RGB":
+             image = image.convert("RGB")

+         inputs = clip_processor(text=labels, images=image, return_tensors="pt", padding=True).to(DEVICE)

+         with torch.no_grad():
+             outputs = clip_model(**inputs)
+             logits_per_image = outputs.logits_per_image # this is the image-text similarity score
+             probs = logits_per_image.softmax(dim=1) # convert to probabilities

+         print("CLIP processing complete.")

+         # Format output for Gradio Label
+         confidences = {labels[i]: float(probs[0, i].item()) for i in range(len(labels))}
+         return confidences, image # Return original image for display alongside results

      except Exception as e:
+         print(f"Error during CLIP processing: {e}")
+         return f"An error occurred: {e}", None


+ # FastSAM Segmentation Function
+ def run_fastsam_segmentation(image_pil: Image.Image, conf_threshold: float = 0.4, iou_threshold: float = 0.9):
+     if fastsam_model is None:
+         load_fastsam_model() # Attempt to load if not already loaded
+         if fastsam_model is None:
+             return "Error: FastSAM Model not loaded. Check logs.", None
+     if image_pil is None:
+         return "Please upload an image.", None

+     print("Running FastSAM segmentation...")

      try:
+         # Ensure image is RGB
+         if image_pil.mode != "RGB":
+             image_pil = image_pil.convert("RGB")

+         # FastSAM expects a BGR numpy array or path usually. Let's try with RGB numpy.
+         # If it fails, uncomment the BGR conversion line.
+         image_np_rgb = np.array(image_pil)
+         # image_np_bgr = image_np_rgb[:, :, ::-1] # Convert RGB to BGR if needed

+         # Run FastSAM inference
+         # Adjust imgsz, conf, iou as needed. Higher imgsz = more detail, slower.
+         everything_results = fastsam_model(
+             image_np_rgb, # Use image_np_bgr if conversion needed
+             device=DEVICE,
+             retina_masks=True,
+             imgsz=640, # Smaller size for faster inference on limited hardware
+             conf=conf_threshold,
+             iou=iou_threshold,
+         )

+         # Process results using FastSAMPrompt
+         from fastsam import FastSAMPrompt # Make sure it's imported
+         prompt_process = FastSAMPrompt(image_np_rgb, everything_results, device=DEVICE)

+         # Get all annotations (masks)
+         ann = prompt_process.everything_prompt()

+         print(f"FastSAM found {len(ann[0]['masks']) if ann and ann[0] else 0} masks.")

+         # --- Plotting Masks on Image (Manual) ---
+         output_image = image_pil.copy()
+         if ann and ann[0] is not None and 'masks' in ann[0] and len(ann[0]['masks']) > 0:
+             masks = ann[0]['masks'].cpu().numpy() # shape (N, H, W)
+
+             # Create overlay image
+             overlay = Image.new('RGBA', output_image.size, (0, 0, 0, 0))
+             draw = ImageDraw.Draw(overlay)

+             for i in range(masks.shape[0]):
+                 mask = masks[i] # shape (H, W), boolean
+
+                 # Generate random color with some transparency
+                 color = (random.randint(50, 255), random.randint(50, 255), random.randint(50, 255), 128) # RGBA with alpha

+                 # Create a single-channel image from the boolean mask
+                 mask_image = Image.fromarray((mask * 255).astype(np.uint8), mode='L')
+
+                 # Apply color to the mask area on the overlay
+                 draw.bitmap((0,0), mask_image, fill=color)

+             # Composite the overlay onto the original image
+             output_image = Image.alpha_composite(output_image.convert('RGBA'), overlay).convert('RGB')

+         print("FastSAM processing and plotting complete.")
+         return output_image, image_pil # Return segmented and original images

      except Exception as e:
+         print(f"Error during FastSAM processing: {e}")
+         import traceback
+         traceback.print_exc() # Print detailed traceback
+         return f"An error occurred: {e}", None
+

  # --- Gradio Interface ---

+ # Pre-load models on startup (optional but good for performance)
+ print("Attempting to preload models...")
+ load_clip_model()
+ load_fastsam_model()
+ print("Preloading finished (or attempted).")
+
+
+ with gr.Blocks(theme=gr.themes.Soft()) as demo:
+     gr.Markdown("# CLIP & FastSAM Demo")
+     gr.Markdown("Explore Zero-Shot Classification with CLIP and 'Segment Anything' with FastSAM.")
+
+     with gr.Tabs():
+         # --- CLIP Tab ---
+         with gr.TabItem("CLIP Zero-Shot Classification"):
+             gr.Markdown("Upload an image and provide comma-separated candidate labels (e.g., 'cat, dog, car'). CLIP will predict the probability of the image matching each label.")
+             with gr.Row():
+                 with gr.Column(scale=1):
+                     clip_input_image = gr.Image(type="pil", label="Input Image")
+                     clip_text_labels = gr.Textbox(label="Comma-Separated Labels", placeholder="e.g., astronaut, mountain, dog playing fetch")
+                     clip_button = gr.Button("Run CLIP Classification", variant="primary")
+                 with gr.Column(scale=1):
+                     clip_output_label = gr.Label(label="Classification Probabilities")
+                     clip_output_image_display = gr.Image(type="pil", label="Input Image Preview") # Show input for context
+
+             clip_button.click(
+                 run_clip_zero_shot,
+                 inputs=[clip_input_image, clip_text_labels],
+                 outputs=[clip_output_label, clip_output_image_display]
+             )
+             gr.Examples(
+                 examples=[
+                     ["examples/astronaut.jpg", "astronaut, moon, rover, mountain"],
+                     ["examples/dog_bike.jpg", "dog, bicycle, person, park, grass"],
+                 ],
+                 inputs=[clip_input_image, clip_text_labels],
+                 outputs=[clip_output_label, clip_output_image_display],
+                 fn=run_clip_zero_shot,
+                 cache_examples=False, # Re-run for live demo
              )

+         # --- FastSAM Tab ---
+         with gr.TabItem("FastSAM Segmentation"):
+             gr.Markdown("Upload an image. FastSAM will attempt to segment all objects/regions in the image.")
+             with gr.Row():
+                 with gr.Column(scale=1):
+                     fastsam_input_image = gr.Image(type="pil", label="Input Image")
+                     with gr.Row():
+                         fastsam_conf = gr.Slider(minimum=0.1, maximum=1.0, value=0.4, step=0.05, label="Confidence Threshold")
+                         fastsam_iou = gr.Slider(minimum=0.1, maximum=1.0, value=0.9, step=0.05, label="IoU Threshold")
+                     fastsam_button = gr.Button("Run FastSAM Segmentation", variant="primary")
+                 with gr.Column(scale=1):
+                     fastsam_output_image = gr.Image(type="pil", label="Segmented Image")
+                     # fastsam_input_display = gr.Image(type="pil", label="Original Image") # Optional: show original side-by-side
+
+             fastsam_button.click(
+                 run_fastsam_segmentation,
+                 inputs=[fastsam_input_image, fastsam_conf, fastsam_iou],
+                 outputs=[fastsam_output_image] # Removed the second output for simplicity, adjust if needed
+             )
+             gr.Examples(
+                 examples=[
+                     ["examples/dogs.jpg", 0.4, 0.9],
+                     ["examples/fruits.jpg", 0.5, 0.8],
+                 ],
+                 inputs=[fastsam_input_image, fastsam_conf, fastsam_iou],
+                 outputs=[fastsam_output_image],
+                 fn=run_fastsam_segmentation,
+                 cache_examples=False, # Re-run for live demo
+             )

+ # Add example images (optional, but helpful)
+ # Create an 'examples' folder and add some jpg images like 'astronaut.jpg', 'dog_bike.jpg', 'dogs.jpg', 'fruits.jpg'
+ if not os.path.exists("examples"):
+     os.makedirs("examples")
+     print("Created 'examples' directory. Please add some images (e.g., astronaut.jpg, dog_bike.jpg) for the examples to work.")
+     # You might need to download some sample images here too if running on a fresh env
+     try:
+         print("Downloading example images...")
+         wget.download("https://huggingface.co/spaces/gradio/image-segmentation/resolve/main/images/lion.jpg", "examples/lion.jpg")
+         wget.download("https://raw.githubusercontent.com/openai/CLIP/main/CLIP.png", "examples/clip_logo.png")
+         wget.download("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/gradio-logo.png", "examples/gradio_logo.png")
+         # Manually add the examples used above if these don't match
+         print("Example images downloaded (or attempted). Please verify.")
+     except Exception as e:
+         print(f"Could not download example images: {e}")
+
+
+ # Launch the Gradio app
+ if __name__ == "__main__":
+     demo.launch(debug=True) # Set debug=False for deployment
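
The refactored helpers can be exercised outside Gradio as a quick sanity check. The sketch below is illustrative only and not part of the commit: it assumes the updated file is saved as app.py and importable as `app` (importing it runs the module-level preloading and weight/example downloads), and that a local test image exists at the `examples/dogs.jpg` path referenced in the Examples above.

```python
# Hypothetical smoke test for the refactored helpers in app.py (not part of this commit).
# Importing `app` runs its module-level setup, so CLIP/FastSAM weights are loaded first.
from PIL import Image

import app  # the updated app.py from this commit

img = Image.open("examples/dogs.jpg")  # assumed local test image

# run_clip_zero_shot returns (label -> probability dict, preview image) on success
probs, _preview = app.run_clip_zero_shot(img, "dog, cat, car")
print(probs)

# run_fastsam_segmentation returns (segmented image, original image) on success
segmented, _original = app.run_fastsam_segmentation(img, conf_threshold=0.4, iou_threshold=0.9)
segmented.save("segmented.jpg")
```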