sagar007 committed
Commit eba2946 · verified · 1 Parent(s): 2cfae42

Update app.py

Files changed (1)
  1. app.py +110 -70
app.py CHANGED
@@ -6,6 +6,7 @@ import numpy as np
 import random
 import os
 import wget # To download weights
+import traceback # For detailed error printing
 
 # --- Configuration & Model Loading ---
 
@@ -30,39 +31,68 @@ def load_clip_model():
     print(f"CLIP model loaded to {DEVICE}.")
 
 # --- FastSAM Setup ---
-# Use a smaller model suitable for Spaces CPU/basic GPU if needed
 FASTSAM_CHECKPOINT = "FastSAM-s.pt"
-FASTSAM_CHECKPOINT_URL = f"https://huggingface.co/spaces/An-619/FastSAM/resolve/main/{FASTSAM_CHECKPOINT}" # Example URL, find official if possible
+# Use the official model hub repo URL
+FASTSAM_CHECKPOINT_URL = f"https://huggingface.co/CASIA-IVA-Lab/FastSAM-s/resolve/main/{FASTSAM_CHECKPOINT}"
 
 fastsam_model = None
+fastsam_lib_imported = False # Flag to check if import worked
+
+def check_and_import_fastsam():
+    global fastsam_lib_imported
+    if not fastsam_lib_imported:
+        try:
+            from fastsam import FastSAM, FastSAMPrompt
+            globals()['FastSAM'] = FastSAM # Make classes available globally
+            globals()['FastSAMPrompt'] = FastSAMPrompt
+            fastsam_lib_imported = True
+            print("fastsam library imported successfully.")
+        except ImportError:
+            print("Error: 'fastsam' library not found or import failed.")
+            print("Please ensure 'fastsam' is installed correctly (pip install fastsam).")
+            fastsam_lib_imported = False
+        except Exception as e:
+            print(f"An unexpected error occurred during fastsam import: {e}")
+            fastsam_lib_imported = False
+    return fastsam_lib_imported
+
 
 def download_fastsam_weights():
     if not os.path.exists(FASTSAM_CHECKPOINT):
-        print(f"Downloading FastSAM weights: {FASTSAM_CHECKPOINT}...")
+        print(f"Downloading FastSAM weights: {FASTSAM_CHECKPOINT} from {FASTSAM_CHECKPOINT_URL}...")
         try:
             wget.download(FASTSAM_CHECKPOINT_URL, FASTSAM_CHECKPOINT)
             print("FastSAM weights downloaded.")
         except Exception as e:
             print(f"Error downloading FastSAM weights: {e}")
             print("Please ensure the URL is correct and reachable, or manually place the weights file.")
+            # Attempt to remove partially downloaded file if exists
+            if os.path.exists(FASTSAM_CHECKPOINT):
+                try:
+                    os.remove(FASTSAM_CHECKPOINT)
+                except OSError:
+                    pass # Ignore removal errors
         return False
     return os.path.exists(FASTSAM_CHECKPOINT)
 
 def load_fastsam_model():
     global fastsam_model
     if fastsam_model is None:
-        if download_fastsam_weights():
+        if not check_and_import_fastsam(): # Check import first
+            print("Cannot load FastSAM model because the library couldn't be imported.")
+            return # Exit if import failed
+
+        if download_fastsam_weights(): # Check download/existence second
             try:
-                from fastsam import FastSAM, FastSAMPrompt # Import here after potential download
+                # FastSAM class should be available via globals() now
                 print(f"Loading FastSAM model: {FASTSAM_CHECKPOINT}...")
                 fastsam_model = FastSAM(FASTSAM_CHECKPOINT)
-                print(f"FastSAM model loaded.") # Device handled internally by FastSAM based on its setup/torch device
-            except ImportError:
-                print("Error: 'fastsam' library not found. Please install it (pip install fastsam).")
+                print(f"FastSAM model loaded.") # Device handled internally by FastSAM
             except Exception as e:
                 print(f"Error loading FastSAM model: {e}")
+                traceback.print_exc()
         else:
-            print("FastSAM weights not found. Cannot load model.")
+            print("FastSAM weights not found or download failed. Cannot load model.")
 
 
 # --- Processing Functions ---
@@ -74,14 +104,16 @@ def run_clip_zero_shot(image: Image.Image, text_labels: str):
     if clip_model is None:
         return "Error: CLIP Model not loaded. Check logs.", None
 
-    if not text_labels:
-        return "Please provide comma-separated text labels.", None
     if image is None:
-        return "Please upload an image.", None
+        return "Please upload an image.", None # Return None for the image display
+    if not text_labels:
+        # Return empty results but display the uploaded image
+        return {}, image
 
-    labels = [label.strip() for label in text_labels.split(',')]
+    labels = [label.strip() for label in text_labels.split(',') if label.strip()] # Ensure non-empty labels
     if not labels:
-        return "No valid labels provided.", None
+        # Return empty results but display the uploaded image
+        return {}, image
 
     print(f"Running CLIP zero-shot classification with labels: {labels}")
 
@@ -94,28 +126,36 @@ def run_clip_zero_shot(image: Image.Image, text_labels: str):
 
         with torch.no_grad():
             outputs = clip_model(**inputs)
-            logits_per_image = outputs.logits_per_image # this is the image-text similarity score
-            probs = logits_per_image.softmax(dim=1) # convert to probabilities
+            logits_per_image = outputs.logits_per_image
+            probs = logits_per_image.softmax(dim=1)
 
         print("CLIP processing complete.")
 
-        # Format output for Gradio Label
         confidences = {labels[i]: float(probs[0, i].item()) for i in range(len(labels))}
-        return confidences, image # Return original image for display alongside results
+        # Return results and the original image used for prediction
+        return confidences, image
 
     except Exception as e:
         print(f"Error during CLIP processing: {e}")
-        return f"An error occurred: {e}", None
+        traceback.print_exc()
+        # Return error message and the original image
+        return f"An error occurred during CLIP: {e}", image
 
 
 # FastSAM Segmentation Function
 def run_fastsam_segmentation(image_pil: Image.Image, conf_threshold: float = 0.4, iou_threshold: float = 0.9):
+    # Ensure model is loaded or attempt to load
     if fastsam_model is None:
-        load_fastsam_model() # Attempt to load if not already loaded
+        load_fastsam_model()
         if fastsam_model is None:
-            return "Error: FastSAM Model not loaded. Check logs.", None
+            # Return error message string for the image component (Gradio handles this)
+            return "Error: FastSAM Model not loaded. Check logs."
+    # Ensure library was imported
+    if not fastsam_lib_imported:
+        return "Error: FastSAM library not available. Cannot run segmentation."
+
     if image_pil is None:
-        return "Please upload an image.", None
+        return "Please upload an image."
 
     print("Running FastSAM segmentation...")
 
@@ -124,63 +164,52 @@ def run_fastsam_segmentation(image_pil: Image.Image, conf_threshold: float = 0.4
         if image_pil.mode != "RGB":
             image_pil = image_pil.convert("RGB")
 
-        # FastSAM expects a BGR numpy array or path usually. Let's try with RGB numpy.
-        # If it fails, uncomment the BGR conversion line.
         image_np_rgb = np.array(image_pil)
-        # image_np_bgr = image_np_rgb[:, :, ::-1] # Convert RGB to BGR if needed
 
         # Run FastSAM inference
-        # Adjust imgsz, conf, iou as needed. Higher imgsz = more detail, slower.
         everything_results = fastsam_model(
-            image_np_rgb, # Use image_np_bgr if conversion needed
+            image_np_rgb,
             device=DEVICE,
             retina_masks=True,
-            imgsz=640, # Smaller size for faster inference on limited hardware
+            imgsz=640,
             conf=conf_threshold,
             iou=iou_threshold,
         )
 
-        # Process results using FastSAMPrompt
-        from fastsam import FastSAMPrompt # Make sure it's imported
+        # FastSAMPrompt should be available via globals() if import succeeded
        prompt_process = FastSAMPrompt(image_np_rgb, everything_results, device=DEVICE)
-
-        # Get all annotations (masks)
         ann = prompt_process.everything_prompt()
 
-        print(f"FastSAM found {len(ann[0]['masks']) if ann and ann[0] else 0} masks.")
+        print(f"FastSAM found {len(ann[0]['masks']) if ann and ann[0] and 'masks' in ann[0] else 0} masks.")
 
-        # --- Plotting Masks on Image (Manual) ---
+        # --- Plotting Masks on Image ---
         output_image = image_pil.copy()
         if ann and ann[0] is not None and 'masks' in ann[0] and len(ann[0]['masks']) > 0:
-            masks = ann[0]['masks'].cpu().numpy() # shape (N, H, W)
-
-            # Create overlay image
+            masks = ann[0]['masks'].cpu().numpy() # (N, H, W) boolean
+
             overlay = Image.new('RGBA', output_image.size, (0, 0, 0, 0))
             draw = ImageDraw.Draw(overlay)
 
             for i in range(masks.shape[0]):
-                mask = masks[i] # shape (H, W), boolean
-
-                # Generate random color with some transparency
-                color = (random.randint(50, 255), random.randint(50, 255), random.randint(50, 255), 128) # RGBA with alpha
-
-                # Create a single-channel image from the boolean mask
+                mask = masks[i]
+                color = (random.randint(50, 255), random.randint(50, 255), random.randint(50, 255), 128) # RGBA
                 mask_image = Image.fromarray((mask * 255).astype(np.uint8), mode='L')
-
-                # Apply color to the mask area on the overlay
                 draw.bitmap((0,0), mask_image, fill=color)
 
-            # Composite the overlay onto the original image
             output_image = Image.alpha_composite(output_image.convert('RGBA'), overlay).convert('RGB')
 
         print("FastSAM processing and plotting complete.")
-        return output_image, image_pil # Return segmented and original images
+        # *** FIX: Return ONLY the output image for the single Image component ***
+        return output_image
 
+    except NameError as ne:
+        print(f"NameError during FastSAM processing: {ne}. Was the fastsam library imported correctly?")
+        traceback.print_exc()
+        return f"A NameError occurred: {ne}. Check library import."
     except Exception as e:
         print(f"Error during FastSAM processing: {e}")
-        import traceback
-        traceback.print_exc() # Print detailed traceback
-        return f"An error occurred: {e}", None
+        traceback.print_exc()
+        return f"An error occurred during FastSAM: {e}"
 
 
 # --- Gradio Interface ---
@@ -188,7 +217,7 @@ def run_fastsam_segmentation(image_pil: Image.Image, conf_threshold: float = 0.4
 # Pre-load models on startup (optional but good for performance)
 print("Attempting to preload models...")
 load_clip_model()
-load_fastsam_model()
+load_fastsam_model() # This will now also attempt download/check import
 print("Preloading finished (or attempted).")
 
 
@@ -203,11 +232,11 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
         with gr.Row():
             with gr.Column(scale=1):
                 clip_input_image = gr.Image(type="pil", label="Input Image")
-                clip_text_labels = gr.Textbox(label="Comma-Separated Labels", placeholder="e.g., astronaut, mountain, dog playing fetch")
+                clip_text_labels = gr.Textbox(label="Comma-Separated Labels", placeholder="e.g., astronaut, moon, dog playing fetch")
                 clip_button = gr.Button("Run CLIP Classification", variant="primary")
             with gr.Column(scale=1):
                 clip_output_label = gr.Label(label="Classification Probabilities")
-                clip_output_image_display = gr.Image(type="pil", label="Input Image Preview") # Show input for context
+                clip_output_image_display = gr.Image(type="pil", label="Input Image Preview")
 
         clip_button.click(
             run_clip_zero_shot,
@@ -218,11 +247,12 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
             examples=[
                 ["examples/astronaut.jpg", "astronaut, moon, rover, mountain"],
                 ["examples/dog_bike.jpg", "dog, bicycle, person, park, grass"],
+                ["examples/clip_logo.png", "logo, text, graphics, abstract art"], # Added another example
             ],
             inputs=[clip_input_image, clip_text_labels],
             outputs=[clip_output_label, clip_output_image_display],
             fn=run_clip_zero_shot,
-            cache_examples=False, # Re-run for live demo
+            cache_examples=False,
         )
 
     # --- FastSAM Tab ---
@@ -237,41 +267,51 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
                 fastsam_button = gr.Button("Run FastSAM Segmentation", variant="primary")
             with gr.Column(scale=1):
                 fastsam_output_image = gr.Image(type="pil", label="Segmented Image")
-                # fastsam_input_display = gr.Image(type="pil", label="Original Image") # Optional: show original side-by-side
 
         fastsam_button.click(
             run_fastsam_segmentation,
             inputs=[fastsam_input_image, fastsam_conf, fastsam_iou],
-            outputs=[fastsam_output_image] # Removed the second output for simplicity, adjust if needed
+            # Output is now correctly mapped to the single component
+            outputs=[fastsam_output_image]
         )
         gr.Examples(
             examples=[
                 ["examples/dogs.jpg", 0.4, 0.9],
                 ["examples/fruits.jpg", 0.5, 0.8],
+                ["examples/lion.jpg", 0.45, 0.9], # Added another example
            ],
             inputs=[fastsam_input_image, fastsam_conf, fastsam_iou],
             outputs=[fastsam_output_image],
             fn=run_fastsam_segmentation,
-            cache_examples=False, # Re-run for live demo
+            cache_examples=False,
         )
 
 # Add example images (optional, but helpful)
-# Create an 'examples' folder and add some jpg images like 'astronaut.jpg', 'dog_bike.jpg', 'dogs.jpg', 'fruits.jpg'
 if not os.path.exists("examples"):
     os.makedirs("examples")
-    print("Created 'examples' directory. Please add some images (e.g., astronaut.jpg, dog_bike.jpg) for the examples to work.")
-    # You might need to download some sample images here too if running on a fresh env
-    try:
-        print("Downloading example images...")
-        wget.download("https://huggingface.co/spaces/gradio/image-segmentation/resolve/main/images/lion.jpg", "examples/lion.jpg")
-        wget.download("https://raw.githubusercontent.com/openai/CLIP/main/CLIP.png", "examples/clip_logo.png")
-        wget.download("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/gradio-logo.png", "examples/gradio_logo.png")
-        # Manually add the examples used above if these don't match
-        print("Example images downloaded (or attempted). Please verify.")
-    except Exception as e:
-        print(f"Could not download example images: {e}")
+    print("Created 'examples' directory. Attempting to download sample images...")
+    example_files = {
+        "astronaut.jpg": "https://upload.wikimedia.org/wikipedia/commons/thumb/d/d astronaut_-_St._Jean_Bay.jpg/640px-Astronaut_-_St._Jean_Bay.jpg", # Find suitable public domain/CC image
+        "dog_bike.jpg": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/gradio/outputs_multimodal.jpg", # Using a relevant example from HF
+        "clip_logo.png": "https://raw.githubusercontent.com/openai/CLIP/main/CLIP.png",
+        "dogs.jpg": "https://raw.githubusercontent.com/ultralytics/assets/main/im/image8.jpg", # From Ultralytics assets
+        "fruits.jpg": "https://raw.githubusercontent.com/ultralytics/assets/main/im/image9.jpg", # From Ultralytics assets
+        "lion.jpg": "https://huggingface.co/spaces/gradio/image-segmentation/resolve/main/images/lion.jpg"
+    }
+    for filename, url in example_files.items():
+        filepath = os.path.join("examples", filename)
+        if not os.path.exists(filepath):
+            try:
+                print(f"Downloading {filename}...")
+                wget.download(url, filepath)
+            except Exception as e:
+                print(f"Could not download {filename} from {url}: {e}")
+    print("Example image download attempt finished.")
 
 
 # Launch the Gradio app
 if __name__ == "__main__":
-    demo.launch(debug=True) # Set debug=False for deployment
+    # share=True is primarily for local testing to get a public link.
+    # Not needed/used when deploying on Hugging Face Spaces.
+    # debug=True is helpful for development. Set to False for production.
    demo.launch(debug=True)
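
The core pattern this commit introduces in check_and_import_fastsam() is a deferred, one-time import whose classes are published via globals() so that later references to FastSAM(...) and FastSAMPrompt(...) resolve only when the library is actually present. A minimal standalone sketch of that pattern (the heavy_dep package and HeavyThing class are hypothetical placeholders, not part of app.py):

import importlib

_dep_ready = False  # module-level flag, mirrors fastsam_lib_imported in app.py

def ensure_heavy_dep() -> bool:
    """Import an optional dependency on first use and cache the outcome."""
    global _dep_ready
    if not _dep_ready:
        try:
            mod = importlib.import_module("heavy_dep")  # hypothetical package name
            globals()["HeavyThing"] = mod.HeavyThing    # same globals() trick as app.py
            _dep_ready = True
        except Exception as e:
            print(f"Optional dependency unavailable: {e}")
    return _dep_ready

if ensure_heavy_dep():
    thing = HeavyThing()  # only reached when the import succeeded

Caching the flag means the import is attempted once rather than on every request, and checking it up front is what lets run_fastsam_segmentation fail fast with a readable error string instead of a NameError when the library is missing.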