|
import gradio as gr |
|
import torch |
|
from transformers import CLIPProcessor, CLIPModel |
|
from datasets import load_dataset |
|
from PIL import Image |
|
|
import matplotlib.pyplot as plt |
|
import os |
|
import glob |
|
from pathlib import Path |
|
import numpy as np |
|
|
|
|
|
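# CLIP model state and cached demo-dataset embeddings, shared across Gradio
# callbacks and populated once at startup.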
|
model = None |
|
processor = None |
|
device = None |
|
demo_data = None |
|
demo_text_emb = None |
|
demo_image_emb = None |
|
|
|
|
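# State for a user-supplied image folder, populated by load_custom_folder().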
|
custom_images = [] |
|
custom_descriptions = [] |
|
custom_paths = [] |
|
custom_image_emb = None |
|
current_data_source = "demo" |
|
|
|
def load_model_and_demo_data(): |
|
"""Load CLIP model and demo dataset""" |
|
global model, processor, device, demo_data, demo_text_emb, demo_image_emb |
|
|
|
try: |
|
|
|
demo_data = load_dataset("jamescalam/image-text-demo", split="train") |
|
|
|
|
|
model_id = "openai/clip-vit-base-patch32" |
|
processor = CLIPProcessor.from_pretrained(model_id) |
|
model = CLIPModel.from_pretrained(model_id) |
|
|
|
|
|
device = 'cuda' if torch.cuda.is_available() else 'cpu' |
|
model.to(device) |
|
|
|
|
|
text = demo_data['text'] |
|
images = demo_data['image'] |
|
|
|
inputs = processor( |
|
text=text, |
|
images=images, |
|
return_tensors="pt", |
|
padding=True, |
|
).to(device) |
|
|
|
        # No gradients are needed at inference time; this saves memory on the batch.
        with torch.no_grad():
            outputs = model(**inputs)
|
|
|
|
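        # L2-normalize the embeddings so dot products equal cosine similarity.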
|
demo_text_emb = outputs.text_embeds |
|
demo_text_emb = demo_text_emb / torch.norm(demo_text_emb, dim=1, keepdim=True) |
|
|
|
demo_image_emb = outputs.image_embeds |
|
demo_image_emb = demo_image_emb / torch.norm(demo_image_emb, dim=1, keepdim=True) |
|
|
|
        return f"✅ Model loaded successfully on {device.upper()}. Demo dataset: {len(demo_data)} images."
|
|
|
except Exception as e: |
|
        return f"❌ Error loading model: {str(e)}"
|
|
|
def load_custom_folder(folder_path): |
|
"""Load images from a custom folder""" |
|
global custom_images, custom_descriptions, custom_paths, custom_image_emb, current_data_source |
|
|
|
if not folder_path or not os.path.exists(folder_path): |
|
        return "❌ Invalid folder path"
|
|
|
try: |
|
supported_formats = ['*.jpg', '*.jpeg', '*.png', '*.bmp', '*.gif', '*.tiff'] |
|
image_paths = [] |
|
|
|
|
|
        # A recursive '**' glob also matches files directly in folder_path, so a
        # single pass covers both flat and nested layouts; uppercase variants
        # are included for case-sensitive filesystems.
        for format_type in supported_formats:
            image_paths.extend(glob.glob(os.path.join(folder_path, '**', format_type), recursive=True))
            image_paths.extend(glob.glob(os.path.join(folder_path, '**', format_type.upper()), recursive=True))

        # De-duplicate (the patterns can overlap) and sort for a stable ordering.
        image_paths = sorted(set(image_paths))
|
|
|
if not image_paths: |
|
            return "❌ No valid images found in the specified folder"
|
|
|
|
|
custom_images.clear() |
|
custom_descriptions.clear() |
|
custom_paths.clear() |
|
|
|
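        # Load at most 100 images to keep embedding computation bounded.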
for img_path in image_paths[:100]: |
|
try: |
|
img = Image.open(img_path).convert('RGB') |
|
custom_images.append(img) |
|
filename = Path(img_path).stem |
|
custom_descriptions.append(f"Image: {filename}") |
|
custom_paths.append(img_path) |
|
            except Exception:
                # Skip files that fail to open or decode.
                continue
|
|
|
if not custom_images: |
|
            return "❌ No valid images could be loaded"
|
|
|
|
|
        custom_image_emb = compute_custom_embeddings(custom_images, custom_descriptions)
        if custom_image_emb is None:
            return "❌ Failed to compute embeddings for the loaded images"

        current_data_source = "custom"
|
|
|
        return f"✅ Loaded {len(custom_images)} images from custom folder"
|
|
|
except Exception as e: |
|
        return f"❌ Error loading custom folder: {str(e)}"
|
|
|
def compute_custom_embeddings(images, descriptions): |
|
"""Compute embeddings for custom images""" |
|
try: |
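        # Embed in small batches to bound memory use; batch results are staged
        # on the CPU and concatenated back onto the compute device at the end.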
|
batch_size = 8 |
|
all_image_embeddings = [] |
|
|
|
for i in range(0, len(images), batch_size): |
|
batch_images = images[i:i+batch_size] |
|
batch_texts = descriptions[i:i+batch_size] |
|
|
|
inputs = processor( |
|
text=batch_texts, |
|
images=batch_images, |
|
return_tensors="pt", |
|
padding=True, |
|
).to(device) |
|
|
|
with torch.no_grad(): |
|
outputs = model(**inputs) |
|
image_emb = outputs.image_embeds |
|
image_emb = image_emb / torch.norm(image_emb, dim=1, keepdim=True) |
|
all_image_embeddings.append(image_emb.cpu()) |
|
|
|
return torch.cat(all_image_embeddings, dim=0).to(device) |
|
|
|
except Exception as e: |
|
print(f"Error computing embeddings: {str(e)}") |
|
return None |
|
|
|
def search_images_by_text(query_text, top_k=5, data_source="demo"): |
|
"""Search images based on text query""" |
|
if not query_text.strip(): |
|
return [], "Please enter a search query" |
|
|
|
try: |
|
|
|
if data_source == "custom" and custom_image_emb is not None: |
|
images = custom_images |
|
descriptions = custom_descriptions |
|
image_emb = custom_image_emb |
|
else: |
|
images = demo_data['image'] |
|
descriptions = demo_data['text'] |
|
image_emb = demo_image_emb |
|
|
|
|
|
inputs = processor(text=[query_text], return_tensors="pt", padding=True).to(device) |
|
|
|
with torch.no_grad(): |
|
text_features = model.get_text_features(**inputs) |
|
text_features = text_features / text_features.norm(dim=-1, keepdim=True) |
|
|
|
|
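        # Both sides are unit length, so this matrix product is cosine similarity.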
|
similarity = torch.mm(text_features, image_emb.T) |
|
|
|
|
|
values, indices = similarity[0].topk(min(top_k, len(images))) |
|
|
|
results = [] |
|
        for idx, score in zip(indices.tolist(), values.tolist()):
            results.append((images[idx], f"Score: {score:.3f}\n{descriptions[idx]}"))
|
|
|
status = f"Found {len(results)} matches for: '{query_text}'" |
|
return results, status |
|
|
|
except Exception as e: |
|
return [], f"Error during search: {str(e)}" |
|
|
|
def search_similar_images(query_image, top_k=5, data_source="demo"): |
|
"""Search similar images based on query image""" |
|
if query_image is None: |
|
return [], "Please provide a query image" |
|
|
|
try: |
|
|
|
if data_source == "custom" and custom_image_emb is not None: |
|
images = custom_images |
|
descriptions = custom_descriptions |
|
image_emb = custom_image_emb |
|
else: |
|
images = demo_data['image'] |
|
descriptions = demo_data['text'] |
|
image_emb = demo_image_emb |
|
|
|
|
|
        inputs = processor(images=query_image, return_tensors="pt").to(device)
|
|
|
with torch.no_grad(): |
|
image_features = model.get_image_features(**inputs) |
|
image_features = image_features / image_features.norm(dim=-1, keepdim=True) |
|
|
|
|
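        # Cosine similarity between the query image and all indexed images.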
|
similarity = torch.mm(image_features, image_emb.T) |
|
|
|
|
|
values, indices = similarity[0].topk(min(top_k, len(images))) |
|
|
|
results = [] |
|
        for idx, score in zip(indices.tolist(), values.tolist()):
            results.append((images[idx], f"Score: {score:.3f}\n{descriptions[idx]}"))
|
|
|
status = f"Found {len(results)} similar images" |
|
return results, status |
|
|
|
except Exception as e: |
|
return [], f"Error during search: {str(e)}" |
|
|
|
def classify_image(image, labels_text): |
|
"""Classify image with custom labels""" |
|
if image is None: |
|
return None, "Please provide an image" |
|
|
|
if not labels_text.strip(): |
|
return None, "Please provide labels" |
|
|
|
try: |
|
labels = [label.strip() for label in labels_text.split('\n') if label.strip()] |
|
|
|
if not labels: |
|
return None, "Please provide valid labels" |
|
|
|
|
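        # Wrap each label in a caption-style prompt; CLIP matches images against
        # natural-language text better than bare class names.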
|
text_prompts = [f"a photo of {label}" for label in labels] |
|
|
|
inputs = processor( |
|
text=text_prompts, |
|
images=image, |
|
return_tensors="pt", |
|
padding=True, |
|
).to(device) |
|
|
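        # logits_per_image has shape (1, num_labels); softmax over the labels
        # yields the zero-shot class probabilities.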
|
with torch.no_grad(): |
|
outputs = model(**inputs) |
|
logits_per_image = outputs.logits_per_image |
|
probs = logits_per_image.softmax(dim=1) |
|
|
|
|
|
probabilities = probs[0].cpu().numpy() |
|
|
|
fig, ax = plt.subplots(figsize=(10, 6)) |
|
bars = ax.barh(labels, probabilities) |
|
ax.set_xlabel('Probability') |
|
ax.set_title('Zero-Shot Classification Results') |
|
|
|
|
|
for i, bar in enumerate(bars): |
|
bar.set_color(plt.cm.viridis(probabilities[i])) |
|
|
|
plt.tight_layout() |
|
|
|
|
|
results_text = "Classification Results:\n\n" |
|
sorted_results = sorted(zip(labels, probabilities), key=lambda x: x[1], reverse=True) |
|
|
|
for label, prob in sorted_results: |
|
results_text += f"{label}: {prob:.3f} ({prob*100:.1f}%)\n" |
|
|
|
return fig, results_text |
|
|
|
except Exception as e: |
|
return None, f"Error during classification: {str(e)}" |
|
|
|
def get_random_demo_images(): |
|
"""Get random images from current dataset""" |
|
try: |
|
if current_data_source == "custom" and custom_images: |
|
images = custom_images |
|
descriptions = custom_descriptions |
|
else: |
|
images = demo_data['image'] |
|
descriptions = demo_data['text'] |
|
|
|
if len(images) == 0: |
|
return [] |
|
|
|
|
|
indices = np.random.choice(len(images), min(6, len(images)), replace=False) |
|
|
|
results = [] |
|
for idx in indices: |
|
results.append((images[idx], f"Image {idx}: {descriptions[idx][:100]}...")) |
|
|
|
return results |
|
|
|
    except Exception:
|
return [] |
|
|
|
def switch_data_source(choice): |
|
"""Switch between demo and custom data source""" |
|
global current_data_source |
|
current_data_source = "demo" if choice == "Demo Dataset" else "custom" |
|
|
|
if current_data_source == "custom" and not custom_images: |
|
        return "⚠️ Custom folder not loaded. Please load a custom folder first."
|
elif current_data_source == "custom": |
|
        return f"✅ Switched to custom folder ({len(custom_images)} images)"
|
else: |
|
        return f"✅ Switched to demo dataset ({len(demo_data)} images)"
|
|
|
|
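# Load the model and precompute demo embeddings once at import time so the UI
# starts with a ready status message.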
|
initialization_status = load_model_and_demo_data() |
|
|
|
|
|
with gr.Blocks(title="AI Image Discovery Studio", theme=gr.themes.Soft()) as demo: |
|
gr.Markdown(""" |
|
    # 🖼️ AI Image Discovery Studio
|
|
|
Search images using natural language or find visually similar content with CLIP embeddings! |
|
""") |
|
|
|
|
|
with gr.Row(): |
|
status_display = gr.Textbox( |
|
value=initialization_status, |
|
label="System Status", |
|
interactive=False |
|
) |
|
|
|
|
|
with gr.Row(): |
|
with gr.Column(scale=1): |
|
data_source_radio = gr.Radio( |
|
["Demo Dataset", "Custom Folder"], |
|
value="Demo Dataset", |
|
label="Data Source" |
|
) |
|
|
|
folder_path_input = gr.Textbox( |
|
label="Custom Folder Path", |
|
placeholder="e.g., /path/to/your/images", |
|
visible=False |
|
) |
|
|
|
load_folder_btn = gr.Button("Load Custom Folder", visible=False) |
|
folder_status = gr.Textbox(label="Folder Status", visible=False, interactive=False) |
|
|
|
with gr.Column(scale=2): |
|
source_status = gr.Textbox( |
|
                value=f"✅ Using demo dataset ({len(demo_data)} images)" if demo_data is not None else "❌ Demo dataset unavailable",
|
label="Current Data Source", |
|
interactive=False |
|
) |
|
|
|
|
|
def toggle_folder_controls(choice): |
|
visible = choice == "Custom Folder" |
|
return ( |
|
gr.update(visible=visible), |
|
gr.update(visible=visible), |
|
gr.update(visible=visible) |
|
) |
|
|
|
data_source_radio.change( |
|
toggle_folder_controls, |
|
inputs=[data_source_radio], |
|
outputs=[folder_path_input, load_folder_btn, folder_status] |
|
) |
|
|
|
|
|
data_source_radio.change( |
|
switch_data_source, |
|
inputs=[data_source_radio], |
|
outputs=[source_status] |
|
) |
|
|
|
|
|
load_folder_btn.click( |
|
load_custom_folder, |
|
inputs=[folder_path_input], |
|
outputs=[folder_status] |
|
) |
|
|
|
|
|
with gr.Tabs(): |
|
|
|
        with gr.TabItem("🔤 Text to Image Search"):
|
gr.Markdown("Enter a text description to find matching images") |
|
|
|
with gr.Row(): |
|
with gr.Column(): |
|
text_query = gr.Textbox( |
|
label="Search Query", |
|
placeholder="e.g., 'Dog running on grass', 'Beautiful sunset over mountains'" |
|
) |
|
text_top_k = gr.Slider(1, 10, value=5, step=1, label="Number of Results") |
|
                    text_search_btn = gr.Button("🔍 Search Images", variant="primary")
|
|
|
with gr.Column(): |
|
text_search_status = gr.Textbox(label="Search Status", interactive=False) |
|
|
|
text_results = gr.Gallery( |
|
label="Search Results", |
|
show_label=True, |
|
elem_id="text_search_gallery", |
|
columns=5, |
|
rows=1, |
|
height="auto" |
|
) |
|
|
|
|
|
text_search_btn.click( |
|
lambda query, top_k, source: search_images_by_text( |
|
query, top_k, "custom" if source == "Custom Folder" else "demo" |
|
), |
|
inputs=[text_query, text_top_k, data_source_radio], |
|
outputs=[text_results, text_search_status] |
|
) |
|
|
|
|
|
        with gr.TabItem("🖼️ Image to Image Search"):
|
gr.Markdown("Upload an image to find visually similar ones") |
|
|
|
with gr.Row(): |
|
with gr.Column(): |
|
query_image = gr.Image(label="Query Image", type="pil") |
|
image_top_k = gr.Slider(1, 10, value=5, step=1, label="Number of Results") |
|
                    image_search_btn = gr.Button("🔍 Find Similar Images", variant="primary")
|
|
|
with gr.Column(): |
|
image_search_status = gr.Textbox(label="Search Status", interactive=False) |
|
|
|
image_results = gr.Gallery( |
|
label="Similar Images", |
|
show_label=True, |
|
elem_id="image_search_gallery", |
|
columns=5, |
|
rows=1, |
|
height="auto" |
|
) |
|
|
|
|
|
image_search_btn.click( |
|
lambda img, top_k, source: search_similar_images( |
|
img, top_k, "custom" if source == "Custom Folder" else "demo" |
|
), |
|
inputs=[query_image, image_top_k, data_source_radio], |
|
outputs=[image_results, image_search_status] |
|
) |
|
|
|
|
|
        with gr.TabItem("🏷️ Zero-Shot Classification"):
|
gr.Markdown("Classify an image with custom labels using CLIP") |
|
|
|
with gr.Row(): |
|
with gr.Column(): |
|
classify_image_input = gr.Image(label="Image to Classify", type="pil") |
|
labels_input = gr.Textbox( |
|
label="Classification Labels (one per line)", |
|
value="cat\ndog\ncar\nbird\nflower", |
|
lines=5 |
|
) |
|
                    classify_btn = gr.Button("🔍 Classify Image", variant="primary")
|
|
|
with gr.Column(): |
|
classification_results = gr.Textbox( |
|
label="Detailed Results", |
|
lines=10, |
|
interactive=False |
|
) |
|
|
|
classification_plot = gr.Plot(label="Classification Results") |
|
|
|
|
|
classify_btn.click( |
|
classify_image, |
|
inputs=[classify_image_input, labels_input], |
|
outputs=[classification_plot, classification_results] |
|
) |
|
|
|
|
|
        with gr.TabItem("📊 Dataset Explorer"):
|
gr.Markdown("Browse through the dataset images") |
|
|
|
with gr.Row(): |
|
                random_sample_btn = gr.Button("🎲 Show Random Sample", variant="primary")
|
|
|
explorer_gallery = gr.Gallery( |
|
label="Dataset Sample", |
|
show_label=True, |
|
elem_id="explorer_gallery", |
|
columns=3, |
|
rows=2, |
|
height="auto" |
|
) |
|
|
|
|
|
random_sample_btn.click( |
|
get_random_demo_images, |
|
outputs=[explorer_gallery] |
|
) |
|
|
|
|
|
if __name__ == "__main__": |
|
demo.launch() |