import gradio as gr
import torch
from PIL import Image
import numpy as np
from transformers import CLIPProcessor, CLIPModel
from ultralytics import FastSAM
import supervision as sv
import os
import requests
from tqdm.auto import tqdm # For a nice progress bar
# --- Constants and Model Initialization ---
# CLIP
CLIP_MODEL_NAME = "openai/clip-vit-base-patch32"
# FastSAM
FASTSAM_WEIGHTS_URL = "https://huggingface.co/spaces/An-619/FastSAM/resolve/main/weights/FastSAM-s.pt"
FASTSAM_WEIGHTS_NAME = "FastSAM-s.pt"
# Default FastSAM parameters
DEFAULT_IMGSZ = 640           # Inference resolution (pixels); larger is slower but more detailed
DEFAULT_CONFIDENCE = 0.4      # Minimum detection confidence
DEFAULT_IOU = 0.9             # IoU threshold for non-maximum suppression
DEFAULT_RETINA_MASKS = False  # High-resolution masks (slower, uses more memory)
# --- Helper Functions ---
def download_file(url, filename):
    """Downloads a file from a URL with a progress bar."""
    response = requests.get(url, stream=True)
    response.raise_for_status()  # Raise an exception for bad status codes
    total_size = int(response.headers.get('content-length', 0))
    block_size = 1024  # 1 KB
    progress_bar = tqdm(total=total_size, unit='iB', unit_scale=True)
    with open(filename, 'wb') as file:
        for data in response.iter_content(block_size):
            progress_bar.update(len(data))
            file.write(data)
    progress_bar.close()
    if total_size != 0 and progress_bar.n != total_size:
        raise ValueError("Error: Download failed.")
# --- Model Loading ---
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load CLIP model and processor
model = CLIPModel.from_pretrained(CLIP_MODEL_NAME).to(device)
processor = CLIPProcessor.from_pretrained(CLIP_MODEL_NAME)
# Load FastSAM model with dynamic device handling
if not os.path.exists(FASTSAM_WEIGHTS_NAME):
    print(f"Downloading FastSAM weights from {FASTSAM_WEIGHTS_URL}...")
    try:
        download_file(FASTSAM_WEIGHTS_URL, FASTSAM_WEIGHTS_NAME)
        print("FastSAM weights downloaded successfully.")
    except Exception as e:
        print(f"Error downloading FastSAM weights: {e}")
        raise  # Re-raise the exception to stop execution

fast_sam = FastSAM(FASTSAM_WEIGHTS_NAME)
fast_sam.to(device)
print(f"FastSAM loaded on device: {device}")
# --- Processing Functions ---
def process_image_clip(image, text_input):
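    """Zero-shot concept check with CLIP.

    Scores the image against the label pair (text_input, "not {text_input}")
    and reports the softmax probability of the positive label.
    """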
    if image is None:
        return "Please upload an image first."
    if not text_input:
        return "Please enter some text to check in the image."
    try:
        # Convert numpy array to PIL Image if needed
        if isinstance(image, np.ndarray):
            image = Image.fromarray(image)

        # Create a list of candidate labels
        candidate_labels = [text_input, f"not {text_input}"]

        # Process image and text, moving tensors to the model's device
        inputs = processor(
            images=image,
            text=candidate_labels,
            return_tensors="pt",
            padding=True
        ).to(device)

        # Get model predictions (no gradients needed for inference)
        with torch.no_grad():
            outputs = model(**inputs)
        logits_per_image = outputs.logits_per_image
        probs = logits_per_image.softmax(dim=1)

        # Get confidence for the positive label
        confidence = float(probs[0][0])
        return f"Confidence that the image contains '{text_input}': {confidence:.2%}"
    except Exception as e:
        return f"Error processing image: {str(e)}"
def process_image_fastsam(image, imgsz, conf, iou, retina_masks):
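    """Automatic segmentation with FastSAM.

    Returns an (annotated PIL image, error message) pair; exactly one of the
    two is None, matching the image and textbox outputs wired to this handler.
    """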
    if image is None:
        return None, "Please upload an image to segment."
    try:
        # Convert PIL image to numpy array if needed
        if isinstance(image, Image.Image):
            image_np = np.array(image)
        else:
            image_np = image

        # Run FastSAM inference
        results = fast_sam(image_np, device=device, retina_masks=retina_masks, imgsz=imgsz, conf=conf, iou=iou)

        # Check if results are valid
        if results is None or len(results) == 0 or results[0] is None:
            return None, "FastSAM did not return valid results. Try adjusting parameters or using a different image."

        # Get detections
        detections = sv.Detections.from_ultralytics(results[0])
        if detections is None or len(detections) == 0:
            return None, "No objects detected in the image. Try lowering the confidence threshold."

        # Annotate the image: masks first, then boxes on top
        box_annotator = sv.BoxAnnotator()
        mask_annotator = sv.MaskAnnotator()
        annotated_image = mask_annotator.annotate(scene=image_np.copy(), detections=detections)
        annotated_image = box_annotator.annotate(scene=annotated_image, detections=detections)

        return Image.fromarray(annotated_image), None  # No error message on success
    except RuntimeError as re:
        if "out of memory" in str(re).lower():
            return None, "Error: Out of memory. Try reducing the image size (imgsz) or disabling retina masks."
        return None, f"Runtime error during FastSAM processing: {str(re)}"
    except Exception as e:
        return None, f"Error processing image with FastSAM: {str(e)}"
# --- Gradio Interface ---
with gr.Blocks(css="footer {visibility: hidden}") as demo:
gr.Markdown("""
# CLIP and FastSAM Demo
This demo combines two powerful AI models:
- **CLIP**: For zero-shot image classification
- **FastSAM**: For automatic image segmentation
Try uploading an image and use either of the tabs below!
""")
with gr.Tab("CLIP Zero-Shot Classification"):
with gr.Row():
image_input = gr.Image(label="Input Image")
text_input = gr.Textbox(
label="What do you want to check in the image?",
placeholder="e.g., 'a dog', 'sunset', 'people playing'",
info="Enter any concept you want to check in the image"
)
output_text = gr.Textbox(label="Result")
classify_btn = gr.Button("Classify")
classify_btn.click(fn=process_image_clip, inputs=[image_input, text_input], outputs=output_text)
gr.Examples(
examples=[
["https://raw.githubusercontent.com/gradio-app/gradio/main/demo/kitchen/kitchen.png", "kitchen"],
["https://raw.githubusercontent.com/gradio-app/gradio/main/demo/calculator/calculator.jpg", "calculator"],
],
inputs=[image_input, text_input],
)
with gr.Tab("FastSAM Segmentation"):
with gr.Row():
image_input_sam = gr.Image(label="Input Image")
with gr.Column():
imgsz_slider = gr.Slider(minimum=320, maximum=1920, step=32, value=DEFAULT_IMGSZ, label="Image Size (imgsz)")
conf_slider = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, value=DEFAULT_CONFIDENCE, label="Confidence Threshold")
iou_slider = gr.Slider(minimum=0.0, maximum=1.0, step=0.05, value=DEFAULT_IOU, label="IoU Threshold")
retina_checkbox = gr.Checkbox(label="Retina Masks", value=DEFAULT_RETINA_MASKS)
with gr.Row():
image_output = gr.Image(label="Segmentation Result")
error_output = gr.Textbox(label="Error Message", type="text") # Added for displaying errors
segment_btn = gr.Button("Segment")
segment_btn.click(
fn=process_image_fastsam,
inputs=[image_input_sam, imgsz_slider, conf_slider, iou_slider, retina_checkbox],
outputs=[image_output, error_output] # Output to both image and error textboxes
)
gr.Examples(
examples=[
["https://raw.githubusercontent.com/gradio-app/gradio/main/demo/kitchen/kitchen.png"],
["https://raw.githubusercontent.com/gradio-app/gradio/main/demo/calculator/calculator.jpg"],
],
inputs=[image_input_sam],
)
gr.Markdown("""
### How to use:
1. **CLIP Classification**: Upload an image and enter text to check if that concept exists in the image
2. **FastSAM Segmentation**: Upload an image to get automatic segmentation with bounding boxes and masks
### Note:
- The models run on CPU by default, so processing might take a few seconds. If you have a GPU, it will be used automatically.
- For best results, use clear images with good lighting.
- You can adjust FastSAM parameters (Image Size, Confidence, IoU, Retina Masks) in the Segmentation tab.
""")
demo.launch(share=True) |
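# The handlers can also be called directly for a quick local test, bypassing the
# UI (hypothetical example; "photo.jpg" is a placeholder path):
#   img = Image.open("photo.jpg")
#   print(process_image_clip(img, "a dog"))
#   annotated, err = process_image_fastsam(
#       img, DEFAULT_IMGSZ, DEFAULT_CONFIDENCE, DEFAULT_IOU, DEFAULT_RETINA_MASKS)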