import gradio as gr
import torch
from transformers import CLIPProcessor, CLIPModel
from datasets import load_dataset
from PIL import Image
import matplotlib.pyplot as plt
import os
import glob
from pathlib import Path
import numpy as np
# Global variables for model and data
model = None
processor = None
device = None
demo_data = None
demo_text_emb = None
demo_image_emb = None
# Custom folder data
custom_images = []
custom_descriptions = []
custom_paths = []
custom_image_emb = None
current_data_source = "demo"
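# Design note: the model, processor, and embeddings live in module-level
# globals so they are loaded once; each search then reduces to a single matrix
# multiply against cached, precomputed vectors instead of re-encoding the
# whole collection.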
def load_model_and_demo_data():
"""Load CLIP model and demo dataset"""
global model, processor, device, demo_data, demo_text_emb, demo_image_emb
try:
# Load dataset
demo_data = load_dataset("jamescalam/image-text-demo", split="train")
# Load model
model_id = "openai/clip-vit-base-patch32"
processor = CLIPProcessor.from_pretrained(model_id)
model = CLIPModel.from_pretrained(model_id)
# Move to device
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model.to(device)
# Pre-compute image embeddings
text = demo_data['text']
images = demo_data['image']
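        # The processor tokenizes the captions and resizes/normalizes the
        # images to CLIP's expected input resolution in one call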
inputs = processor(
text=text,
images=images,
return_tensors="pt",
padding=True,
).to(device)
        # Inference only: no_grad avoids building an autograd graph
        with torch.no_grad():
            outputs = model(**inputs)
# Normalize embeddings
demo_text_emb = outputs.text_embeds
demo_text_emb = demo_text_emb / torch.norm(demo_text_emb, dim=1, keepdim=True)
demo_image_emb = outputs.image_embeds
demo_image_emb = demo_image_emb / torch.norm(demo_image_emb, dim=1, keepdim=True)
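        # With unit-norm embeddings, a plain dot product equals cosine
        # similarity, which is exactly what the search functions compute below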
return f"βœ… Model loaded successfully on {device.upper()}. Demo dataset: {len(demo_data)} images."
except Exception as e:
return f"❌ Error loading model: {str(e)}"
def load_custom_folder(folder_path):
"""Load images from a custom folder"""
global custom_images, custom_descriptions, custom_paths, custom_image_emb, current_data_source
if not folder_path or not os.path.exists(folder_path):
return "❌ Invalid folder path"
try:
        supported_formats = ['*.jpg', '*.jpeg', '*.png', '*.bmp', '*.gif', '*.tiff']
        image_paths = []
        # A recursive glob with '**' also matches files directly in the
        # top-level folder, so a single pass covers it and all subdirectories
        for format_type in supported_formats:
            image_paths.extend(glob.glob(os.path.join(folder_path, '**', format_type), recursive=True))
            image_paths.extend(glob.glob(os.path.join(folder_path, '**', format_type.upper()), recursive=True))
        # Remove duplicates and sort for a stable order
        image_paths = sorted(set(image_paths))
if not image_paths:
return "❌ No valid images found in the specified folder"
# Load images
custom_images.clear()
custom_descriptions.clear()
custom_paths.clear()
for img_path in image_paths[:100]: # Limit to 100 images for demo
try:
img = Image.open(img_path).convert('RGB')
custom_images.append(img)
filename = Path(img_path).stem
custom_descriptions.append(f"Image: {filename}")
custom_paths.append(img_path)
            except Exception:
                # Skip files that PIL cannot open or decode
                continue
if not custom_images:
return "❌ No valid images could be loaded"
# Compute embeddings
custom_image_emb = compute_custom_embeddings(custom_images, custom_descriptions)
current_data_source = "custom"
return f"βœ… Loaded {len(custom_images)} images from custom folder"
except Exception as e:
return f"❌ Error loading custom folder: {str(e)}"
def compute_custom_embeddings(images, descriptions):
"""Compute embeddings for custom images"""
try:
batch_size = 8
all_image_embeddings = []
for i in range(0, len(images), batch_size):
batch_images = images[i:i+batch_size]
batch_texts = descriptions[i:i+batch_size]
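            # CLIPModel's joint forward pass expects both modalities, so the
            # filename-derived captions ride along even though only the image
            # embeddings are kept; model.get_image_features(pixel_values=...)
            # would be an image-only alternative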
inputs = processor(
text=batch_texts,
images=batch_images,
return_tensors="pt",
padding=True,
).to(device)
with torch.no_grad():
outputs = model(**inputs)
image_emb = outputs.image_embeds
image_emb = image_emb / torch.norm(image_emb, dim=1, keepdim=True)
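                # Stage per-batch embeddings on the CPU so GPU memory stays
                # bounded by batch_size; the concatenated matrix moves back to
                # the device once, after the loop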
all_image_embeddings.append(image_emb.cpu())
return torch.cat(all_image_embeddings, dim=0).to(device)
except Exception as e:
print(f"Error computing embeddings: {str(e)}")
return None
def search_images_by_text(query_text, top_k=5, data_source="demo"):
"""Search images based on text query"""
    if not query_text or not query_text.strip():
return [], "Please enter a search query"
try:
# Choose data source
if data_source == "custom" and custom_image_emb is not None:
images = custom_images
descriptions = custom_descriptions
image_emb = custom_image_emb
else:
images = demo_data['image']
descriptions = demo_data['text']
image_emb = demo_image_emb
# Process the text query
inputs = processor(text=[query_text], return_tensors="pt", padding=True).to(device)
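        # get_text_features runs only the text tower and projection head;
        # unlike the joint forward pass it does not normalize, hence the
        # explicit L2 normalization below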
with torch.no_grad():
text_features = model.get_text_features(**inputs)
text_features = text_features / text_features.norm(dim=-1, keepdim=True)
# Calculate similarity scores
similarity = torch.mm(text_features, image_emb.T)
# Get top-k matches
values, indices = similarity[0].topk(min(top_k, len(images)))
results = []
for idx, score in zip(indices, values):
results.append((images[idx], f"Score: {score.item():.3f}\n{descriptions[idx]}"))
status = f"Found {len(results)} matches for: '{query_text}'"
return results, status
except Exception as e:
return [], f"Error during search: {str(e)}"
def search_similar_images(query_image, top_k=5, data_source="demo"):
"""Search similar images based on query image"""
if query_image is None:
return [], "Please provide a query image"
try:
# Choose data source
if data_source == "custom" and custom_image_emb is not None:
images = custom_images
descriptions = custom_descriptions
image_emb = custom_image_emb
else:
images = demo_data['image']
descriptions = demo_data['text']
image_emb = demo_image_emb
# Process the query image
        inputs = processor(images=query_image, return_tensors="pt").to(device)
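        # get_image_features mirrors the text path: vision tower plus
        # projection, normalized manually so scores are cosine similarities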
with torch.no_grad():
image_features = model.get_image_features(**inputs)
image_features = image_features / image_features.norm(dim=-1, keepdim=True)
# Calculate similarity scores
similarity = torch.mm(image_features, image_emb.T)
# Get top-k matches
values, indices = similarity[0].topk(min(top_k, len(images)))
results = []
for idx, score in zip(indices, values):
results.append((images[idx], f"Score: {score.item():.3f}\n{descriptions[idx]}"))
status = f"Found {len(results)} similar images"
return results, status
except Exception as e:
return [], f"Error during search: {str(e)}"
def classify_image(image, labels_text):
"""Classify image with custom labels"""
if image is None:
return None, "Please provide an image"
    if not labels_text or not labels_text.strip():
return None, "Please provide labels"
try:
labels = [label.strip() for label in labels_text.split('\n') if label.strip()]
if not labels:
return None, "Please provide valid labels"
# Prepare text prompts
text_prompts = [f"a photo of {label}" for label in labels]
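        # The "a photo of ..." template mirrors the caption style CLIP was
        # trained on and typically improves zero-shot accuracy over bare labels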
inputs = processor(
text=text_prompts,
images=image,
return_tensors="pt",
padding=True,
).to(device)
with torch.no_grad():
outputs = model(**inputs)
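            # logits_per_image holds the temperature-scaled cosine similarity
            # between the image and each prompt; softmax over the prompts
            # turns it into a probability distribution over the labels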
logits_per_image = outputs.logits_per_image
probs = logits_per_image.softmax(dim=1)
# Create bar chart
probabilities = probs[0].cpu().numpy()
fig, ax = plt.subplots(figsize=(10, 6))
bars = ax.barh(labels, probabilities)
ax.set_xlabel('Probability')
ax.set_title('Zero-Shot Classification Results')
# Color bars based on probability
for i, bar in enumerate(bars):
bar.set_color(plt.cm.viridis(probabilities[i]))
        plt.tight_layout()
        # Closing the figure keeps matplotlib from accumulating open figures
        # across repeated classifications; gr.Plot renders the returned object
        plt.close(fig)
# Create detailed results text
results_text = "Classification Results:\n\n"
sorted_results = sorted(zip(labels, probabilities), key=lambda x: x[1], reverse=True)
for label, prob in sorted_results:
results_text += f"{label}: {prob:.3f} ({prob*100:.1f}%)\n"
return fig, results_text
except Exception as e:
return None, f"Error during classification: {str(e)}"
def get_random_demo_images():
"""Get random images from current dataset"""
try:
if current_data_source == "custom" and custom_images:
images = custom_images
descriptions = custom_descriptions
else:
images = demo_data['image']
descriptions = demo_data['text']
if len(images) == 0:
return []
# Get random indices
indices = np.random.choice(len(images), min(6, len(images)), replace=False)
results = []
for idx in indices:
results.append((images[idx], f"Image {idx}: {descriptions[idx][:100]}..."))
return results
except Exception as e:
return []
def switch_data_source(choice):
"""Switch between demo and custom data source"""
global current_data_source
current_data_source = "demo" if choice == "Demo Dataset" else "custom"
if current_data_source == "custom" and not custom_images:
return "⚠️ Custom folder not loaded. Please load a custom folder first."
elif current_data_source == "custom":
return f"βœ… Switched to custom folder ({len(custom_images)} images)"
else:
return f"βœ… Switched to demo dataset ({len(demo_data)} images)"
# Initialize the model when the module loads
initialization_status = load_model_and_demo_data()
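# Loading at import time means the first render of the UI below already shows
# whether the model and demo data are ready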
# Create Gradio interface
with gr.Blocks(title="AI Image Discovery Studio", theme=gr.themes.Soft()) as demo:
gr.Markdown("""
# πŸ–ΌοΈ AI Image Discovery Studio
Search images using natural language or find visually similar content with CLIP embeddings!
""")
# Status display
with gr.Row():
status_display = gr.Textbox(
value=initialization_status,
label="System Status",
interactive=False
)
# Data source selection and custom folder loading
with gr.Row():
with gr.Column(scale=1):
data_source_radio = gr.Radio(
["Demo Dataset", "Custom Folder"],
value="Demo Dataset",
label="Data Source"
)
folder_path_input = gr.Textbox(
label="Custom Folder Path",
placeholder="e.g., /path/to/your/images",
visible=False
)
load_folder_btn = gr.Button("Load Custom Folder", visible=False)
folder_status = gr.Textbox(label="Folder Status", visible=False, interactive=False)
with gr.Column(scale=2):
source_status = gr.Textbox(
value=f"βœ… Using demo dataset ({len(demo_data)} images)",
label="Current Data Source",
interactive=False
)
# Show/hide custom folder controls based on selection
def toggle_folder_controls(choice):
visible = choice == "Custom Folder"
return (
gr.update(visible=visible), # folder_path_input
gr.update(visible=visible), # load_folder_btn
gr.update(visible=visible) # folder_status
)
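    # gr.update(visible=...) toggles visibility on existing components; the
    # widgets themselves are created once when the Blocks context is built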
data_source_radio.change(
toggle_folder_controls,
inputs=[data_source_radio],
outputs=[folder_path_input, load_folder_btn, folder_status]
)
# Update data source status
data_source_radio.change(
switch_data_source,
inputs=[data_source_radio],
outputs=[source_status]
)
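    # Both .change handlers above bind to the same radio component; Gradio
    # runs each of them on every selection change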
# Load custom folder
load_folder_btn.click(
load_custom_folder,
inputs=[folder_path_input],
outputs=[folder_status]
)
# Main tabs
with gr.Tabs():
# Text to Image Search Tab
with gr.TabItem("πŸ”€ Text to Image Search"):
gr.Markdown("Enter a text description to find matching images")
with gr.Row():
with gr.Column():
text_query = gr.Textbox(
label="Search Query",
placeholder="e.g., 'Dog running on grass', 'Beautiful sunset over mountains'"
)
text_top_k = gr.Slider(1, 10, value=5, step=1, label="Number of Results")
text_search_btn = gr.Button("πŸ” Search Images", variant="primary")
with gr.Column():
text_search_status = gr.Textbox(label="Search Status", interactive=False)
text_results = gr.Gallery(
label="Search Results",
show_label=True,
elem_id="text_search_gallery",
columns=5,
rows=1,
height="auto"
)
# Connect text search
text_search_btn.click(
lambda query, top_k, source: search_images_by_text(
query, top_k, "custom" if source == "Custom Folder" else "demo"
),
inputs=[text_query, text_top_k, data_source_radio],
outputs=[text_results, text_search_status]
)
# Image to Image Search Tab
with gr.TabItem("πŸ–ΌοΈ Image to Image Search"):
gr.Markdown("Upload an image to find visually similar ones")
with gr.Row():
with gr.Column():
query_image = gr.Image(label="Query Image", type="pil")
image_top_k = gr.Slider(1, 10, value=5, step=1, label="Number of Results")
image_search_btn = gr.Button("πŸ” Find Similar Images", variant="primary")
with gr.Column():
image_search_status = gr.Textbox(label="Search Status", interactive=False)
image_results = gr.Gallery(
label="Similar Images",
show_label=True,
elem_id="image_search_gallery",
columns=5,
rows=1,
height="auto"
)
# Connect image search
image_search_btn.click(
lambda img, top_k, source: search_similar_images(
img, top_k, "custom" if source == "Custom Folder" else "demo"
),
inputs=[query_image, image_top_k, data_source_radio],
outputs=[image_results, image_search_status]
)
# Zero-Shot Classification Tab
with gr.TabItem("🏷️ Zero-Shot Classification"):
gr.Markdown("Classify an image with custom labels using CLIP")
with gr.Row():
with gr.Column():
classify_image_input = gr.Image(label="Image to Classify", type="pil")
labels_input = gr.Textbox(
label="Classification Labels (one per line)",
value="cat\ndog\ncar\nbird\nflower",
lines=5
)
classify_btn = gr.Button("πŸ” Classify Image", variant="primary")
with gr.Column():
classification_results = gr.Textbox(
label="Detailed Results",
lines=10,
interactive=False
)
classification_plot = gr.Plot(label="Classification Results")
# Connect classification
classify_btn.click(
classify_image,
inputs=[classify_image_input, labels_input],
outputs=[classification_plot, classification_results]
)
# Dataset Explorer Tab
with gr.TabItem("πŸ“Š Dataset Explorer"):
gr.Markdown("Browse through the dataset images")
with gr.Row():
random_sample_btn = gr.Button("🎲 Show Random Sample", variant="primary")
explorer_gallery = gr.Gallery(
label="Dataset Sample",
show_label=True,
elem_id="explorer_gallery",
columns=3,
rows=2,
height="auto"
)
# Connect random sampling
random_sample_btn.click(
get_random_demo_images,
outputs=[explorer_gallery]
)
# Launch the app
if __name__ == "__main__":
demo.launch()