import gradio as gr
import torch
from transformers import AutoProcessor, AutoModel
from PIL import Image, ImageDraw, ImageFont
import numpy as np
import random
import os
import wget # To download weights
import traceback # For detailed error printing
# --- Configuration & Model Loading ---
# Device Selection
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {DEVICE}")
# --- CLIP Setup ---
CLIP_MODEL_ID = "openai/clip-vit-base-patch32"
clip_processor = None
clip_model = None
def load_clip_model():
global clip_processor, clip_model
if clip_processor is None:
print(f"Loading CLIP processor: {CLIP_MODEL_ID}...")
clip_processor = AutoProcessor.from_pretrained(CLIP_MODEL_ID)
print("CLIP processor loaded.")
if clip_model is None:
print(f"Loading CLIP model: {CLIP_MODEL_ID}...")
clip_model = AutoModel.from_pretrained(CLIP_MODEL_ID).to(DEVICE)
print(f"CLIP model loaded to {DEVICE}.")
# --- FastSAM Setup ---
FASTSAM_CHECKPOINT = "FastSAM-s.pt"
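# FastSAM-s is the small, fast variant; FastSAM-x is the larger, higher-quality checkpoint.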
# Use the official model hub repo URL
FASTSAM_CHECKPOINT_URL = f"https://huggingface.co/CASIA-IVA-Lab/FastSAM-s/resolve/main/{FASTSAM_CHECKPOINT}"
fastsam_model = None
fastsam_lib_imported = False # Flag to check if import worked
def check_and_import_fastsam():
global fastsam_lib_imported
if not fastsam_lib_imported:
try:
from fastsam import FastSAM, FastSAMPrompt
globals()['FastSAM'] = FastSAM # Make classes available globally
globals()['FastSAMPrompt'] = FastSAMPrompt
fastsam_lib_imported = True
print("fastsam library imported successfully.")
except ImportError:
print("Error: 'fastsam' library not found or import failed.")
print("Please ensure 'fastsam' is installed correctly (pip install fastsam).")
fastsam_lib_imported = False
except Exception as e:
print(f"An unexpected error occurred during fastsam import: {e}")
fastsam_lib_imported = False
return fastsam_lib_imported
def download_fastsam_weights():
if not os.path.exists(FASTSAM_CHECKPOINT):
print(f"Downloading FastSAM weights: {FASTSAM_CHECKPOINT} from {FASTSAM_CHECKPOINT_URL}...")
try:
wget.download(FASTSAM_CHECKPOINT_URL, FASTSAM_CHECKPOINT)
print("FastSAM weights downloaded.")
except Exception as e:
print(f"Error downloading FastSAM weights: {e}")
print("Please ensure the URL is correct and reachable, or manually place the weights file.")
# Attempt to remove partially downloaded file if exists
if os.path.exists(FASTSAM_CHECKPOINT):
try:
os.remove(FASTSAM_CHECKPOINT)
except OSError:
pass # Ignore removal errors
return False
return os.path.exists(FASTSAM_CHECKPOINT)
def load_fastsam_model():
global fastsam_model
if fastsam_model is None:
if not check_and_import_fastsam(): # Check import first
print("Cannot load FastSAM model because the library couldn't be imported.")
return # Exit if import failed
if download_fastsam_weights(): # Check download/existence second
try:
# FastSAM class should be available via globals() now
print(f"Loading FastSAM model: {FASTSAM_CHECKPOINT}...")
fastsam_model = FastSAM(FASTSAM_CHECKPOINT)
print(f"FastSAM model loaded.") # Device handled internally by FastSAM
except Exception as e:
print(f"Error loading FastSAM model: {e}")
traceback.print_exc()
else:
print("FastSAM weights not found or download failed. Cannot load model.")
# --- Processing Functions ---
# CLIP Zero-Shot Classification Function
def run_clip_zero_shot(image: Image.Image, text_labels: str):
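    """Classify the image against the comma-separated candidate labels with CLIP."""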
if clip_model is None or clip_processor is None:
load_clip_model() # Attempt to load if not already loaded
if clip_model is None:
return "Error: CLIP Model not loaded. Check logs.", None
if image is None:
return "Please upload an image.", None # Return None for the image display
if not text_labels:
# Return empty results but display the uploaded image
return {}, image
labels = [label.strip() for label in text_labels.split(',') if label.strip()] # Ensure non-empty labels
if not labels:
# Return empty results but display the uploaded image
return {}, image
print(f"Running CLIP zero-shot classification with labels: {labels}")
try:
# Ensure image is RGB
if image.mode != "RGB":
image = image.convert("RGB")
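        # The CLIP processor tokenizes the labels and resizes/normalizes the image in a single call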
inputs = clip_processor(text=labels, images=image, return_tensors="pt", padding=True).to(DEVICE)
with torch.no_grad():
outputs = clip_model(**inputs)
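            # logits_per_image has shape (1, num_labels); softmax over labels yields per-label probabilities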
logits_per_image = outputs.logits_per_image
probs = logits_per_image.softmax(dim=1)
print("CLIP processing complete.")
confidences = {labels[i]: float(probs[0, i].item()) for i in range(len(labels))}
# Return results and the original image used for prediction
return confidences, image
except Exception as e:
print(f"Error during CLIP processing: {e}")
traceback.print_exc()
# Return error message and the original image
return f"An error occurred during CLIP: {e}", image
# FastSAM Segmentation Function
def run_fastsam_segmentation(image_pil: Image.Image, conf_threshold: float = 0.4, iou_threshold: float = 0.9):
# Ensure model is loaded or attempt to load
if fastsam_model is None:
load_fastsam_model()
if fastsam_model is None:
# Return error message string for the image component (Gradio handles this)
return "Error: FastSAM Model not loaded. Check logs."
# Ensure library was imported
if not fastsam_lib_imported:
return "Error: FastSAM library not available. Cannot run segmentation."
if image_pil is None:
return "Please upload an image."
print("Running FastSAM segmentation...")
try:
# Ensure image is RGB
if image_pil.mode != "RGB":
image_pil = image_pil.convert("RGB")
image_np_rgb = np.array(image_pil)
# Run FastSAM inference
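        # retina_masks=True requests masks at the original image resolution; imgsz is the inference size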
everything_results = fastsam_model(
image_np_rgb,
device=DEVICE,
retina_masks=True,
imgsz=640,
conf=conf_threshold,
iou=iou_threshold,
)
# FastSAMPrompt should be available via globals() if import succeeded
prompt_process = FastSAMPrompt(image_np_rgb, everything_results, device=DEVICE)
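        # everything_prompt() collects the annotations for every region found by the "segment everything" pass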
ann = prompt_process.everything_prompt()
print(f"FastSAM found {len(ann[0]['masks']) if ann and ann[0] and 'masks' in ann[0] else 0} masks.")
# --- Plotting Masks on Image ---
output_image = image_pil.copy()
if ann and ann[0] is not None and 'masks' in ann[0] and len(ann[0]['masks']) > 0:
masks = ann[0]['masks'].cpu().numpy() # (N, H, W) boolean
overlay = Image.new('RGBA', output_image.size, (0, 0, 0, 0))
draw = ImageDraw.Draw(overlay)
for i in range(masks.shape[0]):
mask = masks[i]
color = (random.randint(50, 255), random.randint(50, 255), random.randint(50, 255), 128) # RGBA
mask_image = Image.fromarray((mask * 255).astype(np.uint8), mode='L')
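                # draw.bitmap() uses the mask as a stencil, filling its nonzero pixels with the chosen color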
draw.bitmap((0,0), mask_image, fill=color)
output_image = Image.alpha_composite(output_image.convert('RGBA'), overlay).convert('RGB')
print("FastSAM processing and plotting complete.")
        # Return only the output image, matching the single Image output component
return output_image
except NameError as ne:
print(f"NameError during FastSAM processing: {ne}. Was the fastsam library imported correctly?")
traceback.print_exc()
return f"A NameError occurred: {ne}. Check library import."
except Exception as e:
print(f"Error during FastSAM processing: {e}")
traceback.print_exc()
return f"An error occurred during FastSAM: {e}"
# --- Gradio Interface ---
# Pre-load models on startup (optional but good for performance)
print("Attempting to preload models...")
load_clip_model()
load_fastsam_model() # This will now also attempt download/check import
print("Preloading finished (or attempted).")
with gr.Blocks(theme=gr.themes.Soft()) as demo:
gr.Markdown("# CLIP & FastSAM Demo")
gr.Markdown("Explore Zero-Shot Classification with CLIP and 'Segment Anything' with FastSAM.")
with gr.Tabs():
# --- CLIP Tab ---
with gr.TabItem("CLIP Zero-Shot Classification"):
gr.Markdown("Upload an image and provide comma-separated candidate labels (e.g., 'cat, dog, car'). CLIP will predict the probability of the image matching each label.")
with gr.Row():
with gr.Column(scale=1):
clip_input_image = gr.Image(type="pil", label="Input Image")
clip_text_labels = gr.Textbox(label="Comma-Separated Labels", placeholder="e.g., astronaut, moon, dog playing fetch")
clip_button = gr.Button("Run CLIP Classification", variant="primary")
with gr.Column(scale=1):
clip_output_label = gr.Label(label="Classification Probabilities")
clip_output_image_display = gr.Image(type="pil", label="Input Image Preview")
clip_button.click(
run_clip_zero_shot,
inputs=[clip_input_image, clip_text_labels],
outputs=[clip_output_label, clip_output_image_display]
)
gr.Examples(
examples=[
["examples/astronaut.jpg", "astronaut, moon, rover, mountain"],
["examples/dog_bike.jpg", "dog, bicycle, person, park, grass"],
["examples/clip_logo.png", "logo, text, graphics, abstract art"], # Added another example
],
inputs=[clip_input_image, clip_text_labels],
outputs=[clip_output_label, clip_output_image_display],
fn=run_clip_zero_shot,
cache_examples=False,
)
# --- FastSAM Tab ---
with gr.TabItem("FastSAM Segmentation"):
gr.Markdown("Upload an image. FastSAM will attempt to segment all objects/regions in the image.")
with gr.Row():
with gr.Column(scale=1):
fastsam_input_image = gr.Image(type="pil", label="Input Image")
with gr.Row():
fastsam_conf = gr.Slider(minimum=0.1, maximum=1.0, value=0.4, step=0.05, label="Confidence Threshold")
fastsam_iou = gr.Slider(minimum=0.1, maximum=1.0, value=0.9, step=0.05, label="IoU Threshold")
fastsam_button = gr.Button("Run FastSAM Segmentation", variant="primary")
with gr.Column(scale=1):
fastsam_output_image = gr.Image(type="pil", label="Segmented Image")
fastsam_button.click(
run_fastsam_segmentation,
inputs=[fastsam_input_image, fastsam_conf, fastsam_iou],
# Output is now correctly mapped to the single component
outputs=[fastsam_output_image]
)
gr.Examples(
examples=[
["examples/dogs.jpg", 0.4, 0.9],
["examples/fruits.jpg", 0.5, 0.8],
["examples/lion.jpg", 0.45, 0.9], # Added another example
],
inputs=[fastsam_input_image, fastsam_conf, fastsam_iou],
outputs=[fastsam_output_image],
fn=run_fastsam_segmentation,
cache_examples=False,
)
# Add example images (optional, but helpful)
if not os.path.exists("examples"):
os.makedirs("examples")
print("Created 'examples' directory. Attempting to download sample images...")
example_files = {
"astronaut.jpg": "https://upload.wikimedia.org/wikipedia/commons/thumb/d/d astronaut_-_St._Jean_Bay.jpg/640px-Astronaut_-_St._Jean_Bay.jpg", # Find suitable public domain/CC image
"dog_bike.jpg": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/gradio/outputs_multimodal.jpg", # Using a relevant example from HF
"clip_logo.png": "https://raw.githubusercontent.com/openai/CLIP/main/CLIP.png",
"dogs.jpg": "https://raw.githubusercontent.com/ultralytics/assets/main/im/image8.jpg", # From Ultralytics assets
"fruits.jpg": "https://raw.githubusercontent.com/ultralytics/assets/main/im/image9.jpg", # From Ultralytics assets
"lion.jpg": "https://huggingface.co/spaces/gradio/image-segmentation/resolve/main/images/lion.jpg"
}
for filename, url in example_files.items():
filepath = os.path.join("examples", filename)
if not os.path.exists(filepath):
try:
print(f"Downloading {filename}...")
wget.download(url, filepath)
except Exception as e:
print(f"Could not download {filename} from {url}: {e}")
print("Example image download attempt finished.")
# Launch the Gradio app
if __name__ == "__main__":
# share=True is primarily for local testing to get a public link.
# Not needed/used when deploying on Hugging Face Spaces.
# debug=True is helpful for development. Set to False for production.
demo.launch(debug=True) |