import gradio as gr
import torch
from transformers import CLIPProcessor, CLIPModel
from datasets import load_dataset
from PIL import Image
import matplotlib.pyplot as plt
import os
import glob
from pathlib import Path
import numpy as np

# Global variables for model and data
model = None
processor = None
device = None
demo_data = None
demo_text_emb = None
demo_image_emb = None

# Custom folder data
custom_images = []
custom_descriptions = []
custom_paths = []
custom_image_emb = None
current_data_source = "demo"


def load_model_and_demo_data():
    """Load the CLIP model and demo dataset, and pre-compute demo embeddings."""
    global model, processor, device, demo_data, demo_text_emb, demo_image_emb
    try:
        # Load dataset
        demo_data = load_dataset("jamescalam/image-text-demo", split="train")

        # Load model
        model_id = "openai/clip-vit-base-patch32"
        processor = CLIPProcessor.from_pretrained(model_id)
        model = CLIPModel.from_pretrained(model_id)

        # Move to device
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        model.to(device)

        # Pre-compute embeddings for the whole demo set
        text = demo_data['text']
        images = demo_data['image']
        inputs = processor(
            text=text,
            images=images,
            return_tensors="pt",
            padding=True,
        ).to(device)
        with torch.no_grad():  # inference only; don't build a gradient graph
            outputs = model(**inputs)

        # Normalize embeddings
        demo_text_emb = outputs.text_embeds
        demo_text_emb = demo_text_emb / torch.norm(demo_text_emb, dim=1, keepdim=True)
        demo_image_emb = outputs.image_embeds
        demo_image_emb = demo_image_emb / torch.norm(demo_image_emb, dim=1, keepdim=True)

        return f"✅ Model loaded successfully on {device.upper()}. Demo dataset: {len(demo_data)} images."
    except Exception as e:
        return f"❌ Error loading model: {str(e)}"


def load_custom_folder(folder_path):
    """Load images from a custom folder."""
    global custom_images, custom_descriptions, custom_paths, custom_image_emb, current_data_source

    if not folder_path or not os.path.exists(folder_path):
        return "❌ Invalid folder path"

    try:
        supported_formats = ['*.jpg', '*.jpeg', '*.png', '*.bmp', '*.gif', '*.tiff']
        image_paths = []

        # Search the folder and all subdirectories; with recursive=True the
        # '**' pattern also matches the top level, so one pass suffices
        for format_type in supported_formats:
            image_paths.extend(glob.glob(os.path.join(folder_path, '**', format_type), recursive=True))
            image_paths.extend(glob.glob(os.path.join(folder_path, '**', format_type.upper()), recursive=True))

        # Remove duplicates and sort
        image_paths = sorted(set(image_paths))

        if not image_paths:
            return "❌ No valid images found in the specified folder"

        # Load images
        custom_images.clear()
        custom_descriptions.clear()
        custom_paths.clear()

        for img_path in image_paths[:100]:  # Limit to 100 images for demo
            try:
                img = Image.open(img_path).convert('RGB')
                custom_images.append(img)
                filename = Path(img_path).stem
                custom_descriptions.append(f"Image: {filename}")
                custom_paths.append(img_path)
            except Exception:
                continue

        if not custom_images:
            return "❌ No valid images could be loaded"

        # Compute embeddings
        custom_image_emb = compute_custom_embeddings(custom_images, custom_descriptions)
        current_data_source = "custom"

        return f"✅ Loaded {len(custom_images)} images from custom folder"
    except Exception as e:
        return f"❌ Error loading custom folder: {str(e)}"
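# All embeddings in this app are L2-normalized, so the torch.mm calls in the
# search functions below compute cosine similarity directly.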
def compute_custom_embeddings(images, descriptions):
    """Compute embeddings for custom images"""
    try:
        batch_size = 8
        all_image_embeddings = []

        for i in range(0, len(images), batch_size):
            batch_images = images[i:i + batch_size]
            batch_texts = descriptions[i:i + batch_size]

            inputs = processor(
                text=batch_texts,
                images=batch_images,
                return_tensors="pt",
                padding=True,
            ).to(device)

            with torch.no_grad():
                outputs = model(**inputs)

            image_emb = outputs.image_embeds
            image_emb = image_emb / torch.norm(image_emb, dim=1, keepdim=True)
            all_image_embeddings.append(image_emb.cpu())

        return torch.cat(all_image_embeddings, dim=0).to(device)
    except Exception as e:
        print(f"Error computing embeddings: {str(e)}")
        return None


def search_images_by_text(query_text, top_k=5, data_source="demo"):
    """Search images based on text query"""
    if not query_text.strip():
        return [], "Please enter a search query"

    try:
        # Choose data source
        if data_source == "custom" and custom_image_emb is not None:
            images = custom_images
            descriptions = custom_descriptions
            image_emb = custom_image_emb
        else:
            images = demo_data['image']
            descriptions = demo_data['text']
            image_emb = demo_image_emb

        # Process the text query
        inputs = processor(text=[query_text], return_tensors="pt", padding=True).to(device)
        with torch.no_grad():
            text_features = model.get_text_features(**inputs)
            text_features = text_features / text_features.norm(dim=-1, keepdim=True)

        # Calculate similarity scores
        similarity = torch.mm(text_features, image_emb.T)

        # Get top-k matches
        values, indices = similarity[0].topk(min(top_k, len(images)))

        results = []
        for idx, score in zip(indices, values):
            results.append((images[idx], f"Score: {score.item():.3f}\n{descriptions[idx]}"))

        status = f"Found {len(results)} matches for: '{query_text}'"
        return results, status
    except Exception as e:
        return [], f"Error during search: {str(e)}"


def search_similar_images(query_image, top_k=5, data_source="demo"):
    """Search similar images based on query image"""
    if query_image is None:
        return [], "Please provide a query image"

    try:
        # Choose data source
        if data_source == "custom" and custom_image_emb is not None:
            images = custom_images
            descriptions = custom_descriptions
            image_emb = custom_image_emb
        else:
            images = demo_data['image']
            descriptions = demo_data['text']
            image_emb = demo_image_emb

        # Process the query image
        inputs = processor(images=query_image, return_tensors="pt", padding=True).to(device)
        with torch.no_grad():
            image_features = model.get_image_features(**inputs)
            image_features = image_features / image_features.norm(dim=-1, keepdim=True)

        # Calculate similarity scores
        similarity = torch.mm(image_features, image_emb.T)

        # Get top-k matches
        values, indices = similarity[0].topk(min(top_k, len(images)))

        results = []
        for idx, score in zip(indices, values):
            results.append((images[idx], f"Score: {score.item():.3f}\n{descriptions[idx]}"))

        status = f"Found {len(results)} similar images"
        return results, status
    except Exception as e:
        return [], f"Error during search: {str(e)}"
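# Zero-shot classification scores the image against one prompt per label
# ("a photo of <label>"); softmax over logits_per_image turns the scores
# into a probability distribution across the provided labels.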
def classify_image(image, labels_text):
    """Classify image with custom labels"""
    if image is None:
        return None, "Please provide an image"

    if not labels_text.strip():
        return None, "Please provide labels"

    try:
        labels = [label.strip() for label in labels_text.split('\n') if label.strip()]
        if not labels:
            return None, "Please provide valid labels"

        # Prepare text prompts
        text_prompts = [f"a photo of {label}" for label in labels]

        inputs = processor(
            text=text_prompts,
            images=image,
            return_tensors="pt",
            padding=True,
        ).to(device)

        with torch.no_grad():
            outputs = model(**inputs)
            logits_per_image = outputs.logits_per_image
            probs = logits_per_image.softmax(dim=1)

        # Create bar chart
        probabilities = probs[0].cpu().numpy()
        fig, ax = plt.subplots(figsize=(10, 6))
        bars = ax.barh(labels, probabilities)
        ax.set_xlabel('Probability')
        ax.set_title('Zero-Shot Classification Results')

        # Color bars based on probability
        for i, bar in enumerate(bars):
            bar.set_color(plt.cm.viridis(probabilities[i]))

        plt.tight_layout()

        # Create detailed results text
        results_text = "Classification Results:\n\n"
        sorted_results = sorted(zip(labels, probabilities), key=lambda x: x[1], reverse=True)
        for label, prob in sorted_results:
            results_text += f"{label}: {prob:.3f} ({prob*100:.1f}%)\n"

        return fig, results_text
    except Exception as e:
        return None, f"Error during classification: {str(e)}"


def get_random_demo_images():
    """Get random images from current dataset"""
    try:
        if current_data_source == "custom" and custom_images:
            images = custom_images
            descriptions = custom_descriptions
        else:
            images = demo_data['image']
            descriptions = demo_data['text']

        if len(images) == 0:
            return []

        # Get random indices
        indices = np.random.choice(len(images), min(6, len(images)), replace=False)

        results = []
        for idx in indices:
            results.append((images[idx], f"Image {idx}: {descriptions[idx][:100]}..."))

        return results
    except Exception:
        return []


def switch_data_source(choice):
    """Switch between demo and custom data source"""
    global current_data_source
    current_data_source = "demo" if choice == "Demo Dataset" else "custom"

    if current_data_source == "custom" and not custom_images:
        return "⚠️ Custom folder not loaded. Please load a custom folder first."
    elif current_data_source == "custom":
        return f"✅ Switched to custom folder ({len(custom_images)} images)"
    else:
        return f"✅ Switched to demo dataset ({len(demo_data)} images)"


# Initialize the model when the module loads
initialization_status = load_model_and_demo_data()
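# Gradio UI: a system-status row, data-source controls, and four tabs
# (text-to-image search, image-to-image search, zero-shot classification,
# and a dataset explorer), all sharing the same data-source selection.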
""") # Status display with gr.Row(): status_display = gr.Textbox( value=initialization_status, label="System Status", interactive=False ) # Data source selection and custom folder loading with gr.Row(): with gr.Column(scale=1): data_source_radio = gr.Radio( ["Demo Dataset", "Custom Folder"], value="Demo Dataset", label="Data Source" ) folder_path_input = gr.Textbox( label="Custom Folder Path", placeholder="e.g., /path/to/your/images", visible=False ) load_folder_btn = gr.Button("Load Custom Folder", visible=False) folder_status = gr.Textbox(label="Folder Status", visible=False, interactive=False) with gr.Column(scale=2): source_status = gr.Textbox( value=f"✅ Using demo dataset ({len(demo_data)} images)", label="Current Data Source", interactive=False ) # Show/hide custom folder controls based on selection def toggle_folder_controls(choice): visible = choice == "Custom Folder" return ( gr.update(visible=visible), # folder_path_input gr.update(visible=visible), # load_folder_btn gr.update(visible=visible) # folder_status ) data_source_radio.change( toggle_folder_controls, inputs=[data_source_radio], outputs=[folder_path_input, load_folder_btn, folder_status] ) # Update data source status data_source_radio.change( switch_data_source, inputs=[data_source_radio], outputs=[source_status] ) # Load custom folder load_folder_btn.click( load_custom_folder, inputs=[folder_path_input], outputs=[folder_status] ) # Main tabs with gr.Tabs(): # Text to Image Search Tab with gr.TabItem("🔤 Text to Image Search"): gr.Markdown("Enter a text description to find matching images") with gr.Row(): with gr.Column(): text_query = gr.Textbox( label="Search Query", placeholder="e.g., 'Dog running on grass', 'Beautiful sunset over mountains'" ) text_top_k = gr.Slider(1, 10, value=5, step=1, label="Number of Results") text_search_btn = gr.Button("🔍 Search Images", variant="primary") with gr.Column(): text_search_status = gr.Textbox(label="Search Status", interactive=False) text_results = gr.Gallery( label="Search Results", show_label=True, elem_id="text_search_gallery", columns=5, rows=1, height="auto" ) # Connect text search text_search_btn.click( lambda query, top_k, source: search_images_by_text( query, top_k, "custom" if source == "Custom Folder" else "demo" ), inputs=[text_query, text_top_k, data_source_radio], outputs=[text_results, text_search_status] ) # Image to Image Search Tab with gr.TabItem("🖼️ Image to Image Search"): gr.Markdown("Upload an image to find visually similar ones") with gr.Row(): with gr.Column(): query_image = gr.Image(label="Query Image", type="pil") image_top_k = gr.Slider(1, 10, value=5, step=1, label="Number of Results") image_search_btn = gr.Button("🔍 Find Similar Images", variant="primary") with gr.Column(): image_search_status = gr.Textbox(label="Search Status", interactive=False) image_results = gr.Gallery( label="Similar Images", show_label=True, elem_id="image_search_gallery", columns=5, rows=1, height="auto" ) # Connect image search image_search_btn.click( lambda img, top_k, source: search_similar_images( img, top_k, "custom" if source == "Custom Folder" else "demo" ), inputs=[query_image, image_top_k, data_source_radio], outputs=[image_results, image_search_status] ) # Zero-Shot Classification Tab with gr.TabItem("🏷️ Zero-Shot Classification"): gr.Markdown("Classify an image with custom labels using CLIP") with gr.Row(): with gr.Column(): classify_image_input = gr.Image(label="Image to Classify", type="pil") labels_input = gr.Textbox( label="Classification Labels (one per line)", 
value="cat\ndog\ncar\nbird\nflower", lines=5 ) classify_btn = gr.Button("🔍 Classify Image", variant="primary") with gr.Column(): classification_results = gr.Textbox( label="Detailed Results", lines=10, interactive=False ) classification_plot = gr.Plot(label="Classification Results") # Connect classification classify_btn.click( classify_image, inputs=[classify_image_input, labels_input], outputs=[classification_plot, classification_results] ) # Dataset Explorer Tab with gr.TabItem("📊 Dataset Explorer"): gr.Markdown("Browse through the dataset images") with gr.Row(): random_sample_btn = gr.Button("🎲 Show Random Sample", variant="primary") explorer_gallery = gr.Gallery( label="Dataset Sample", show_label=True, elem_id="explorer_gallery", columns=3, rows=2, height="auto" ) # Connect random sampling random_sample_btn.click( get_random_demo_images, outputs=[explorer_gallery] ) # Launch the app if __name__ == "__main__": demo.launch()