# SegmentVision / app.py
import gradio as gr
import torch
from transformers import AutoProcessor, AutoModel
from PIL import Image, ImageDraw, ImageFont
import numpy as np
import random
import os
import wget # To download weights
# --- Configuration & Model Loading ---
# Device Selection
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Using device: {DEVICE}")
# --- CLIP Setup ---
CLIP_MODEL_ID = "openai/clip-vit-base-patch32"
clip_processor = None
clip_model = None
def load_clip_model():
    global clip_processor, clip_model
    if clip_processor is None:
        print(f"Loading CLIP processor: {CLIP_MODEL_ID}...")
        clip_processor = AutoProcessor.from_pretrained(CLIP_MODEL_ID)
        print("CLIP processor loaded.")
    if clip_model is None:
        print(f"Loading CLIP model: {CLIP_MODEL_ID}...")
        clip_model = AutoModel.from_pretrained(CLIP_MODEL_ID).to(DEVICE)
        print(f"CLIP model loaded to {DEVICE}.")
# --- FastSAM Setup ---
# Use a smaller model suitable for Spaces CPU/basic GPU if needed
FASTSAM_CHECKPOINT = "FastSAM-s.pt"
FASTSAM_CHECKPOINT_URL = f"https://huggingface.co/spaces/An-619/FastSAM/resolve/main/{FASTSAM_CHECKPOINT}" # Example URL, find official if possible
fastsam_model = None
def download_fastsam_weights():
    if not os.path.exists(FASTSAM_CHECKPOINT):
        print(f"Downloading FastSAM weights: {FASTSAM_CHECKPOINT}...")
        try:
            wget.download(FASTSAM_CHECKPOINT_URL, FASTSAM_CHECKPOINT)
            print("FastSAM weights downloaded.")
        except Exception as e:
            print(f"Error downloading FastSAM weights: {e}")
            print("Please ensure the URL is correct and reachable, or manually place the weights file.")
            return False
    return os.path.exists(FASTSAM_CHECKPOINT)
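# Optional alternative (a sketch, not used by the app): fetch the same checkpoint with
# huggingface_hub instead of wget. The repo_id below assumes the weights are hosted in the
# An-619/FastSAM Space referenced by FASTSAM_CHECKPOINT_URL.
def download_fastsam_weights_hf():
    try:
        from huggingface_hub import hf_hub_download  # installed alongside transformers
        return hf_hub_download(
            repo_id="An-619/FastSAM",   # assumed host Space for the checkpoint
            repo_type="space",
            filename=FASTSAM_CHECKPOINT,
        )
    except Exception as e:
        print(f"hf_hub_download failed: {e}")
        return None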
def load_fastsam_model():
    global fastsam_model
    if fastsam_model is None:
        if download_fastsam_weights():
            try:
                from fastsam import FastSAM, FastSAMPrompt  # import here, after the potential weight download
                print(f"Loading FastSAM model: {FASTSAM_CHECKPOINT}...")
                fastsam_model = FastSAM(FASTSAM_CHECKPOINT)
                print("FastSAM model loaded.")  # device placement is handled internally by FastSAM
            except ImportError:
                print("Error: 'fastsam' library not found. Please install it (pip install fastsam).")
            except Exception as e:
                print(f"Error loading FastSAM model: {e}")
        else:
            print("FastSAM weights not found. Cannot load model.")
# --- Processing Functions ---
# CLIP Zero-Shot Classification Function
def run_clip_zero_shot(image: Image.Image, text_labels: str):
    if clip_model is None or clip_processor is None:
        load_clip_model()  # attempt to load if not already loaded
        if clip_model is None:
            return "Error: CLIP Model not loaded. Check logs.", None
    if not text_labels:
        return "Please provide comma-separated text labels.", None
    if image is None:
        return "Please upload an image.", None

    labels = [label.strip() for label in text_labels.split(',') if label.strip()]
    if not labels:
        return "No valid labels provided.", None
    print(f"Running CLIP zero-shot classification with labels: {labels}")

    try:
        # Ensure the image is RGB
        if image.mode != "RGB":
            image = image.convert("RGB")

        inputs = clip_processor(text=labels, images=image, return_tensors="pt", padding=True).to(DEVICE)
        with torch.no_grad():
            outputs = clip_model(**inputs)
            logits_per_image = outputs.logits_per_image  # image-text similarity scores
            probs = logits_per_image.softmax(dim=1)      # convert the scores to probabilities over the labels

        print("CLIP processing complete.")

        # Format the output for the Gradio Label component
        confidences = {labels[i]: float(probs[0, i].item()) for i in range(len(labels))}
        return confidences, image  # return the original image for display alongside the results
    except Exception as e:
        print(f"Error during CLIP processing: {e}")
        return f"An error occurred: {e}", None
# FastSAM Segmentation Function
def run_fastsam_segmentation(image_pil: Image.Image, conf_threshold: float = 0.4, iou_threshold: float = 0.9):
    if fastsam_model is None:
        load_fastsam_model()  # attempt to load if not already loaded
        if fastsam_model is None:
            raise gr.Error("FastSAM model not loaded. Check the logs.")
    if image_pil is None:
        raise gr.Error("Please upload an image.")
    print("Running FastSAM segmentation...")

    try:
        # Ensure the image is RGB
        if image_pil.mode != "RGB":
            image_pil = image_pil.convert("RGB")

        # FastSAM usually expects a BGR numpy array or an image path; try the RGB array first.
        # If the results look wrong, uncomment the BGR conversion below and pass image_np_bgr instead.
        image_np_rgb = np.array(image_pil)
        # image_np_bgr = image_np_rgb[:, :, ::-1]  # convert RGB to BGR if needed

        # Run FastSAM inference.
        # Adjust imgsz, conf and iou as needed: a higher imgsz gives more detail but is slower.
        everything_results = fastsam_model(
            image_np_rgb,  # use image_np_bgr if the conversion is needed
            device=DEVICE,
            retina_masks=True,
            imgsz=640,  # smaller size for faster inference on limited hardware
            conf=conf_threshold,
            iou=iou_threshold,
        )

        # Process the results with FastSAMPrompt
        from fastsam import FastSAMPrompt  # make sure it is imported
        prompt_process = FastSAMPrompt(image_np_rgb, everything_results, device=DEVICE)

        # Get all annotations (masks)
        ann = prompt_process.everything_prompt()
        num_masks = len(ann[0]['masks']) if ann and ann[0] is not None and 'masks' in ann[0] else 0
        print(f"FastSAM found {num_masks} masks.")

        # --- Plot the masks on the image (manual overlay) ---
        output_image = image_pil.copy()
        if num_masks > 0:
            masks = ann[0]['masks'].cpu().numpy()  # shape (N, H, W), boolean

            # Draw every mask onto a transparent overlay with a random color
            overlay = Image.new('RGBA', output_image.size, (0, 0, 0, 0))
            draw = ImageDraw.Draw(overlay)
            for i in range(masks.shape[0]):
                mask = masks[i]  # shape (H, W), boolean
                # Random color with 50% alpha
                color = (random.randint(50, 255), random.randint(50, 255), random.randint(50, 255), 128)
                # Convert the boolean mask to a single-channel image
                mask_image = Image.fromarray((mask * 255).astype(np.uint8), mode='L')
                # Paint the mask area onto the overlay
                draw.bitmap((0, 0), mask_image, fill=color)

            # Composite the overlay onto the original image
            output_image = Image.alpha_composite(output_image.convert('RGBA'), overlay).convert('RGB')

        print("FastSAM processing and plotting complete.")
        return output_image  # single output: the segmented image

    except Exception as e:
        print(f"Error during FastSAM processing: {e}")
        import traceback
        traceback.print_exc()  # log a detailed traceback
        raise gr.Error(f"An error occurred during FastSAM processing: {e}")
# --- Gradio Interface ---
# Pre-load models on startup (optional but good for performance)
print("Attempting to preload models...")
load_clip_model()
load_fastsam_model()
print("Preloading finished (or attempted).")
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# CLIP & FastSAM Demo")
    gr.Markdown("Explore Zero-Shot Classification with CLIP and 'Segment Anything' with FastSAM.")

    with gr.Tabs():
        # --- CLIP Tab ---
        with gr.TabItem("CLIP Zero-Shot Classification"):
            gr.Markdown("Upload an image and provide comma-separated candidate labels (e.g., 'cat, dog, car'). CLIP will predict the probability of the image matching each label.")
            with gr.Row():
                with gr.Column(scale=1):
                    clip_input_image = gr.Image(type="pil", label="Input Image")
                    clip_text_labels = gr.Textbox(label="Comma-Separated Labels", placeholder="e.g., astronaut, mountain, dog playing fetch")
                    clip_button = gr.Button("Run CLIP Classification", variant="primary")
                with gr.Column(scale=1):
                    clip_output_label = gr.Label(label="Classification Probabilities")
                    clip_output_image_display = gr.Image(type="pil", label="Input Image Preview")  # show the input for context

            clip_button.click(
                run_clip_zero_shot,
                inputs=[clip_input_image, clip_text_labels],
                outputs=[clip_output_label, clip_output_image_display]
            )
            gr.Examples(
                examples=[
                    ["examples/astronaut.jpg", "astronaut, moon, rover, mountain"],
                    ["examples/dog_bike.jpg", "dog, bicycle, person, park, grass"],
                ],
                inputs=[clip_input_image, clip_text_labels],
                outputs=[clip_output_label, clip_output_image_display],
                fn=run_clip_zero_shot,
                cache_examples=False,  # re-run for the live demo
            )
        # --- FastSAM Tab ---
        with gr.TabItem("FastSAM Segmentation"):
            gr.Markdown("Upload an image. FastSAM will attempt to segment all objects/regions in the image.")
            with gr.Row():
                with gr.Column(scale=1):
                    fastsam_input_image = gr.Image(type="pil", label="Input Image")
                    with gr.Row():
                        fastsam_conf = gr.Slider(minimum=0.1, maximum=1.0, value=0.4, step=0.05, label="Confidence Threshold")
                        fastsam_iou = gr.Slider(minimum=0.1, maximum=1.0, value=0.9, step=0.05, label="IoU Threshold")
                    fastsam_button = gr.Button("Run FastSAM Segmentation", variant="primary")
                with gr.Column(scale=1):
                    fastsam_output_image = gr.Image(type="pil", label="Segmented Image")
                    # fastsam_input_display = gr.Image(type="pil", label="Original Image")  # optional: show the original side-by-side

            fastsam_button.click(
                run_fastsam_segmentation,
                inputs=[fastsam_input_image, fastsam_conf, fastsam_iou],
                outputs=[fastsam_output_image]  # the segmentation function returns a single image
            )
            gr.Examples(
                examples=[
                    ["examples/dogs.jpg", 0.4, 0.9],
                    ["examples/fruits.jpg", 0.5, 0.8],
                ],
                inputs=[fastsam_input_image, fastsam_conf, fastsam_iou],
                outputs=[fastsam_output_image],
                fn=run_fastsam_segmentation,
                cache_examples=False,  # re-run for the live demo
            )
# Add example images (optional, but helpful)
# Create an 'examples' folder and add some jpg images like 'astronaut.jpg', 'dog_bike.jpg', 'dogs.jpg', 'fruits.jpg'
if not os.path.exists("examples"):
os.makedirs("examples")
print("Created 'examples' directory. Please add some images (e.g., astronaut.jpg, dog_bike.jpg) for the examples to work.")
# You might need to download some sample images here too if running on a fresh env
try:
print("Downloading example images...")
wget.download("https://huggingface.co/spaces/gradio/image-segmentation/resolve/main/images/lion.jpg", "examples/lion.jpg")
wget.download("https://raw.githubusercontent.com/openai/CLIP/main/CLIP.png", "examples/clip_logo.png")
wget.download("https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/gradio-logo.png", "examples/gradio_logo.png")
# Manually add the examples used above if these don't match
print("Example images downloaded (or attempted). Please verify.")
except Exception as e:
print(f"Could not download example images: {e}")
# Launch the Gradio app
if __name__ == "__main__":
demo.launch(debug=True) # Set debug=False for deployment