sagar007 committed on
Commit 6facde6 · verified · 1 Parent(s): 3d6a9c7

Update app.py

Files changed (1)
  1. app.py +120 -38
app.py CHANGED
@@ -7,31 +7,82 @@ from transformers import CLIPProcessor, CLIPModel
  from ultralytics import FastSAM
  import supervision as sv
  import os
+ import requests
+ from tqdm.auto import tqdm  # For a nice progress bar

- # Load CLIP model
- model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
- processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
+ # --- Constants and Model Initialization ---

- # Initialize FastSAM model
- FASTSAM_WEIGHTS = "FastSAM-s.pt"
- if not os.path.exists(FASTSAM_WEIGHTS):
-     os.system(f"wget https://huggingface.co/spaces/An-619/FastSAM/resolve/main/weights/{FASTSAM_WEIGHTS}")
- fast_sam = FastSAM(FASTSAM_WEIGHTS)
+ # CLIP
+ CLIP_MODEL_NAME = "openai/clip-vit-base-patch32"
+
+ # FastSAM
+ FASTSAM_WEIGHTS_URL = "https://huggingface.co/spaces/An-619/FastSAM/resolve/main/weights/FastSAM-s.pt"
+ FASTSAM_WEIGHTS_NAME = "FastSAM-s.pt"
+
+ # Default FastSAM parameters
+ DEFAULT_IMGSZ = 640
+ DEFAULT_CONFIDENCE = 0.4
+ DEFAULT_IOU = 0.9
+ DEFAULT_RETINA_MASKS = False
+
+ # --- Helper Functions ---
+
+ def download_file(url, filename):
+     """Downloads a file from a URL with a progress bar."""
+     response = requests.get(url, stream=True)
+     response.raise_for_status()  # Raise an exception for bad status codes
+
+     total_size = int(response.headers.get('content-length', 0))
+     block_size = 1024  # 1 KB
+     progress_bar = tqdm(total=total_size, unit='iB', unit_scale=True)
+
+     with open(filename, 'wb') as file:
+         for data in response.iter_content(block_size):
+             progress_bar.update(len(data))
+             file.write(data)
+     progress_bar.close()
+
+     if total_size != 0 and progress_bar.n != total_size:
+         raise ValueError("Error: Download failed.")
+
+ # --- Model Loading ---
+
+ # Load CLIP model (this part is correct in your original code)
+ model = CLIPModel.from_pretrained(CLIP_MODEL_NAME)
+ processor = CLIPProcessor.from_pretrained(CLIP_MODEL_NAME)
+
+ # Load FastSAM model with dynamic device handling
+ if not os.path.exists(FASTSAM_WEIGHTS_NAME):
+     print(f"Downloading FastSAM weights from {FASTSAM_WEIGHTS_URL}...")
+     try:
+         download_file(FASTSAM_WEIGHTS_URL, FASTSAM_WEIGHTS_NAME)
+         print("FastSAM weights downloaded successfully.")
+     except Exception as e:
+         print(f"Error downloading FastSAM weights: {e}")
+         raise  # Re-raise the exception to stop execution
+
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ fast_sam = FastSAM(FASTSAM_WEIGHTS_NAME)
+ fast_sam.to(device)
+ print(f"FastSAM loaded on device: {device}")
+
+ # --- Processing Functions ---

  def process_image_clip(image, text_input):
+     # ... (Your CLIP processing function remains the same) ...
      if image is None:
          return "Please upload an image first."
      if not text_input:
          return "Please enter some text to check in the image."
-
+
      try:
          # Convert numpy array to PIL Image if needed
          if isinstance(image, np.ndarray):
              image = Image.fromarray(image)
-
+
          # Create a list of candidate labels
          candidate_labels = [text_input, f"not {text_input}"]
-
+
          # Process image and text
          inputs = processor(
              images=image,
@@ -39,70 +90,86 @@ def process_image_clip(image, text_input):
              return_tensors="pt",
              padding=True
          )
-
+
          # Get model predictions
          outputs = model(**{k: v for k, v in inputs.items()})
          logits_per_image = outputs.logits_per_image
          probs = logits_per_image.softmax(dim=1)
-
+
          # Get confidence for the positive label
          confidence = float(probs[0][0])
          return f"Confidence that the image contains '{text_input}': {confidence:.2%}"
      except Exception as e:
          return f"Error processing image: {str(e)}"

- def process_image_fastsam(image):
+ def process_image_fastsam(image, imgsz, conf, iou, retina_masks):
      if image is None:
-         return None
-
+         return None, "Please upload an image to segment."
+
      try:
          # Convert PIL image to numpy array if needed
          if isinstance(image, Image.Image):
              image_np = np.array(image)
          else:
              image_np = image
-
+
          # Run FastSAM inference
-         results = fast_sam(image_np, device='cpu', retina_masks=True, imgsz=1024, conf=0.4, iou=0.9)
-
+         results = fast_sam(image_np, device=device, retina_masks=retina_masks, imgsz=imgsz, conf=conf, iou=iou)
+
+         # Check if results are valid
+         if results is None or len(results) == 0 or results[0] is None:
+             return None, "FastSAM did not return valid results. Try adjusting parameters or using a different image."
+
          # Get detections
          detections = sv.Detections.from_ultralytics(results[0])
-
+         # Check if detections are valid
+         if detections is None or len(detections) == 0:
+             return None, "No objects detected in the image. Try lowering the confidence threshold."
+
          # Create annotator
          box_annotator = sv.BoxAnnotator()
          mask_annotator = sv.MaskAnnotator()
-
+
          # Annotate image
          annotated_image = mask_annotator.annotate(scene=image_np.copy(), detections=detections)
          annotated_image = box_annotator.annotate(scene=annotated_image, detections=detections)
-
-         return Image.fromarray(annotated_image)
+
+         return Image.fromarray(annotated_image), None  # Return None for the error message since there's no error
+
+     except RuntimeError as re:
+         if "out of memory" in str(re).lower():
+             return None, "Error: Out of memory. Try reducing the image size (imgsz) or disabling retina masks."
+         else:
+             return None, f"Runtime error during FastSAM processing: {str(re)}"
+
      except Exception as e:
-         return f"Error processing image: {str(e)}"
+         return None, f"Error processing image with FastSAM: {str(e)}"
+
+ # --- Gradio Interface ---

- # Create Gradio interface
  with gr.Blocks(css="footer {visibility: hidden}") as demo:
+     # ... (Your Markdown and CLIP tab remain mostly the same) ...
      gr.Markdown("""
      # CLIP and FastSAM Demo
      This demo combines two powerful AI models:
      - **CLIP**: For zero-shot image classification
      - **FastSAM**: For automatic image segmentation
-
+
      Try uploading an image and use either of the tabs below!
      """)
-
+
      with gr.Tab("CLIP Zero-Shot Classification"):
          with gr.Row():
              image_input = gr.Image(label="Input Image")
              text_input = gr.Textbox(
-                 label="What do you want to check in the image?",
+                 label="What do you want to check in the image?",
                  placeholder="e.g., 'a dog', 'sunset', 'people playing'",
                  info="Enter any concept you want to check in the image"
              )
          output_text = gr.Textbox(label="Result")
          classify_btn = gr.Button("Classify")
          classify_btn.click(fn=process_image_clip, inputs=[image_input, text_input], outputs=output_text)
-
+
          gr.Examples(
              examples=[
                  ["https://raw.githubusercontent.com/gradio-app/gradio/main/demo/kitchen/kitchen.png", "kitchen"],
@@ -110,14 +177,27 @@ with gr.Blocks(css="footer {visibility: hidden}") as demo:
              ],
              inputs=[image_input, text_input],
          )
-
+
      with gr.Tab("FastSAM Segmentation"):
          with gr.Row():
              image_input_sam = gr.Image(label="Input Image")
-             image_output = gr.Image(label="Segmentation Result")
+             with gr.Column():
+                 imgsz_slider = gr.Slider(minimum=320, maximum=1920, step=32, value=DEFAULT_IMGSZ, label="Image Size (imgsz)")
+                 conf_slider = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, value=DEFAULT_CONFIDENCE, label="Confidence Threshold")
+                 iou_slider = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, value=DEFAULT_IOU, label="IoU Threshold")
+                 retina_checkbox = gr.Checkbox(label="Retina Masks", value=DEFAULT_RETINA_MASKS)
+
+         with gr.Row():
+             image_output = gr.Image(label="Segmentation Result")
+             error_output = gr.Textbox(label="Error Message", type="text")  # Added for displaying errors
+
          segment_btn = gr.Button("Segment")
-         segment_btn.click(fn=process_image_fastsam, inputs=[image_input_sam], outputs=image_output)
-
+         segment_btn.click(
+             fn=process_image_fastsam,
+             inputs=[image_input_sam, imgsz_slider, conf_slider, iou_slider, retina_checkbox],
+             outputs=[image_output, error_output]  # Output to both image and error textboxes
+         )
+
          gr.Examples(
              examples=[
                  ["https://raw.githubusercontent.com/gradio-app/gradio/main/demo/kitchen/kitchen.png"],
@@ -125,15 +205,17 @@ with gr.Blocks(css="footer {visibility: hidden}") as demo:
              ],
              inputs=[image_input_sam],
          )
-
+
+     # ... (Your final Markdown remains the same) ...
      gr.Markdown("""
      ### How to use:
      1. **CLIP Classification**: Upload an image and enter text to check if that concept exists in the image
      2. **FastSAM Segmentation**: Upload an image to get automatic segmentation with bounding boxes and masks
-
+
      ### Note:
-     - The models run on CPU, so processing might take a few seconds
-     - For best results, use clear images with good lighting
+     - The models run on CPU by default, so processing might take a few seconds. If you have a GPU, it will be used automatically.
+     - For best results, use clear images with good lighting.
+     - You can adjust FastSAM parameters (Image Size, Confidence, IoU, Retina Masks) in the Segmentation tab.
      """)

- demo.launch(share=True)
+ demo.launch(share=True)
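For reference, a minimal sketch of how the updated segmentation handler could be exercised outside Gradio. This is an assumption-laden example, not part of the commit: it presumes the module-level objects from the new app.py (fast_sam, device) are already initialized as in the diff above, and "kitchen.png" is a hypothetical local test file; the parameter values simply mirror the new defaults.

# Hypothetical smoke test for the new process_image_fastsam signature.
# Assumes app.py's globals (fast_sam, device) are initialized as in the diff.
from PIL import Image

test_image = Image.open("kitchen.png")   # hypothetical local test file
annotated, error = process_image_fastsam(
    test_image,
    imgsz=640,           # DEFAULT_IMGSZ
    conf=0.4,            # DEFAULT_CONFIDENCE
    iou=0.9,             # DEFAULT_IOU
    retina_masks=False,  # DEFAULT_RETINA_MASKS
)
if error:
    print(error)                     # second return value carries the error message
else:
    annotated.save("segmented.png")  # first return value is the annotated PIL image

The (image, error) tuple returned here corresponds to the two Gradio outputs (image_output, error_output) wired up in the Segmentation tab of the diff.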