codemaker2015 committed · verified
Commit dfbff50 · 1 Parent(s): 763cbc5

Upload gradio_app.py

Files changed (1)
  1. gradio_app.py +519 -0
gradio_app.py ADDED
@@ -0,0 +1,519 @@
+ import gradio as gr
+ import torch
+ from transformers import CLIPProcessor, CLIPModel
+ from datasets import load_dataset
+ from PIL import Image
+ import requests
+ import matplotlib.pyplot as plt
+ import os
+ import glob
+ from pathlib import Path
+ import numpy as np
+ import io
+ import base64
+
+ # Global variables for model and data
+ model = None
+ processor = None
+ device = None
+ demo_data = None
+ demo_text_emb = None
+ demo_image_emb = None
+
+ # Custom folder data
+ custom_images = []
+ custom_descriptions = []
+ custom_paths = []
+ custom_image_emb = None
+ current_data_source = "demo"
+
+ def load_model_and_demo_data():
+     """Load CLIP model and demo dataset"""
+     global model, processor, device, demo_data, demo_text_emb, demo_image_emb
+
+     try:
+         # Load dataset
+         demo_data = load_dataset("jamescalam/image-text-demo", split="train")
+
+         # Load model
+         model_id = "openai/clip-vit-base-patch32"
+         processor = CLIPProcessor.from_pretrained(model_id)
+         model = CLIPModel.from_pretrained(model_id)
+
+         # Move to device
+         device = 'cuda' if torch.cuda.is_available() else 'cpu'
+         model.to(device)
+
+         # Pre-compute embeddings for the whole demo set
+         text = demo_data['text']
+         images = demo_data['image']
+
+         inputs = processor(
+             text=text,
+             images=images,
+             return_tensors="pt",
+             padding=True,
+         ).to(device)
+
+         # Inference only, so skip gradient tracking
+         with torch.no_grad():
+             outputs = model(**inputs)
+
+         # L2-normalize embeddings so that dot products equal cosine similarity
+         demo_text_emb = outputs.text_embeds
+         demo_text_emb = demo_text_emb / torch.norm(demo_text_emb, dim=1, keepdim=True)
+
+         demo_image_emb = outputs.image_embeds
+         demo_image_emb = demo_image_emb / torch.norm(demo_image_emb, dim=1, keepdim=True)
+
+         return f"✅ Model loaded successfully on {device.upper()}. Demo dataset: {len(demo_data)} images."
+
+     except Exception as e:
+         return f"❌ Error loading model: {str(e)}"
+
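+ # Since both embedding matrices are unit-normalized above, ranking a query against
+ # every image reduces to a single matrix product of cosine similarities, e.g.
+ # (illustrative only):
+ #   scores = demo_text_emb @ demo_image_emb.T   # shape: (num_texts, num_images)
+ # The search functions below use this pattern via torch.mm.
+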
+ def load_custom_folder(folder_path):
+     """Load images from a custom folder"""
+     global custom_images, custom_descriptions, custom_paths, custom_image_emb, current_data_source
+
+     if not folder_path or not os.path.exists(folder_path):
+         return "❌ Invalid folder path"
+
+     try:
+         supported_formats = ['*.jpg', '*.jpeg', '*.png', '*.bmp', '*.gif', '*.tiff']
+         image_paths = []
+
+         # Get all image files from the top level of the folder
+         for format_type in supported_formats:
+             image_paths.extend(glob.glob(os.path.join(folder_path, format_type)))
+             image_paths.extend(glob.glob(os.path.join(folder_path, format_type.upper())))
+
+         # Also search in subdirectories
+         for format_type in supported_formats:
+             image_paths.extend(glob.glob(os.path.join(folder_path, '**', format_type), recursive=True))
+             image_paths.extend(glob.glob(os.path.join(folder_path, '**', format_type.upper()), recursive=True))
+
+         # Remove duplicates (the recursive pass re-matches top-level files) and sort
+         image_paths = sorted(set(image_paths))
+
+         if not image_paths:
+             return "❌ No valid images found in the specified folder"
+
+         # Load images
+         custom_images.clear()
+         custom_descriptions.clear()
+         custom_paths.clear()
+
+         for img_path in image_paths[:100]:  # Limit to 100 images for demo
+             try:
+                 img = Image.open(img_path).convert('RGB')
+                 custom_images.append(img)
+                 filename = Path(img_path).stem
+                 custom_descriptions.append(f"Image: {filename}")
+                 custom_paths.append(img_path)
+             except Exception:
+                 # Skip unreadable or corrupt files
+                 continue
+
+         if not custom_images:
+             return "❌ No valid images could be loaded"
+
+         # Compute embeddings
+         custom_image_emb = compute_custom_embeddings(custom_images, custom_descriptions)
+         current_data_source = "custom"
+
+         return f"✅ Loaded {len(custom_images)} images from custom folder"
+
+     except Exception as e:
+         return f"❌ Error loading custom folder: {str(e)}"
+
+ def compute_custom_embeddings(images, descriptions):
+     """Compute embeddings for custom images"""
+     try:
+         batch_size = 8
+         all_image_embeddings = []
+
+         for i in range(0, len(images), batch_size):
+             batch_images = images[i:i+batch_size]
+             batch_texts = descriptions[i:i+batch_size]
+
+             inputs = processor(
+                 text=batch_texts,
+                 images=batch_images,
+                 return_tensors="pt",
+                 padding=True,
+             ).to(device)
+
+             with torch.no_grad():
+                 outputs = model(**inputs)
+                 image_emb = outputs.image_embeds
+                 image_emb = image_emb / torch.norm(image_emb, dim=1, keepdim=True)
+                 # Park each batch on the CPU so GPU memory stays bounded
+                 all_image_embeddings.append(image_emb.cpu())
+
+         return torch.cat(all_image_embeddings, dim=0).to(device)
+
+     except Exception as e:
+         print(f"Error computing embeddings: {str(e)}")
+         return None
+
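+ # Note: only image_embeds are consumed above, so a leaner variant (an untested
+ # sketch, not part of this app) could skip the placeholder texts entirely:
+ #   pixels = processor(images=batch_images, return_tensors="pt").to(device)
+ #   image_emb = model.get_image_features(**pixels)
+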
+ def search_images_by_text(query_text, top_k=5, data_source="demo"):
+     """Search images based on text query"""
+     if not query_text.strip():
+         return [], "Please enter a search query"
+
+     try:
+         # Choose data source
+         if data_source == "custom" and custom_image_emb is not None:
+             images = custom_images
+             descriptions = custom_descriptions
+             image_emb = custom_image_emb
+         else:
+             images = demo_data['image']
+             descriptions = demo_data['text']
+             image_emb = demo_image_emb
+
+         # Process the text query
+         inputs = processor(text=[query_text], return_tensors="pt", padding=True).to(device)
+
+         with torch.no_grad():
+             text_features = model.get_text_features(**inputs)
+             text_features = text_features / text_features.norm(dim=-1, keepdim=True)
+
+         # Calculate similarity scores
+         similarity = torch.mm(text_features, image_emb.T)
+
+         # Get top-k matches
+         values, indices = similarity[0].topk(min(top_k, len(images)))
+
+         results = []
+         for idx, score in zip(indices, values):
+             results.append((images[idx], f"Score: {score.item():.3f}\n{descriptions[idx]}"))
+
+         status = f"Found {len(results)} matches for: '{query_text}'"
+         return results, status
+
+     except Exception as e:
+         return [], f"Error during search: {str(e)}"
+
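+ # Example call (hypothetical, assuming the model and demo data loaded):
+ #   gallery_items, status = search_images_by_text("a dog running on grass", top_k=3)
+ #   # gallery_items is a list of (PIL.Image, caption) pairs, as gr.Gallery expects
+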
+ def search_similar_images(query_image, top_k=5, data_source="demo"):
+     """Search similar images based on query image"""
+     if query_image is None:
+         return [], "Please provide a query image"
+
+     try:
+         # Choose data source
+         if data_source == "custom" and custom_image_emb is not None:
+             images = custom_images
+             descriptions = custom_descriptions
+             image_emb = custom_image_emb
+         else:
+             images = demo_data['image']
+             descriptions = demo_data['text']
+             image_emb = demo_image_emb
+
+         # Process the query image (padding only applies to text, so it is omitted here)
+         inputs = processor(images=query_image, return_tensors="pt").to(device)
+
+         with torch.no_grad():
+             image_features = model.get_image_features(**inputs)
+             image_features = image_features / image_features.norm(dim=-1, keepdim=True)
+
+         # Calculate similarity scores
+         similarity = torch.mm(image_features, image_emb.T)
+
+         # Get top-k matches
+         values, indices = similarity[0].topk(min(top_k, len(images)))
+
+         results = []
+         for idx, score in zip(indices, values):
+             results.append((images[idx], f"Score: {score.item():.3f}\n{descriptions[idx]}"))
+
+         status = f"Found {len(results)} similar images"
+         return results, status
+
+     except Exception as e:
+         return [], f"Error during search: {str(e)}"
+
+ def classify_image(image, labels_text):
+     """Classify image with custom labels"""
+     if image is None:
+         return None, "Please provide an image"
+
+     if not labels_text.strip():
+         return None, "Please provide labels"
+
+     try:
+         labels = [label.strip() for label in labels_text.split('\n') if label.strip()]
+
+         if not labels:
+             return None, "Please provide valid labels"
+
+         # Prepare text prompts
+         text_prompts = [f"a photo of {label}" for label in labels]
+
+         inputs = processor(
+             text=text_prompts,
+             images=image,
+             return_tensors="pt",
+             padding=True,
+         ).to(device)
+
+         with torch.no_grad():
+             outputs = model(**inputs)
+             logits_per_image = outputs.logits_per_image
+             probs = logits_per_image.softmax(dim=1)
+
+         # Create bar chart
+         probabilities = probs[0].cpu().numpy()
+
+         fig, ax = plt.subplots(figsize=(10, 6))
+         bars = ax.barh(labels, probabilities)
+         ax.set_xlabel('Probability')
+         ax.set_title('Zero-Shot Classification Results')
+
+         # Color bars based on probability
+         for i, bar in enumerate(bars):
+             bar.set_color(plt.cm.viridis(probabilities[i]))
+
+         plt.tight_layout()
+
+         # Create detailed results text
+         results_text = "Classification Results:\n\n"
+         sorted_results = sorted(zip(labels, probabilities), key=lambda x: x[1], reverse=True)
+
+         for label, prob in sorted_results:
+             results_text += f"{label}: {prob:.3f} ({prob*100:.1f}%)\n"
+
+         return fig, results_text
+
+     except Exception as e:
+         return None, f"Error during classification: {str(e)}"
+
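+ # Background: logits_per_image is the image-text cosine similarity scaled by the
+ # model's learned temperature (logit_scale), so the softmax yields probabilities
+ # relative to the supplied label set only; CLIP is never told about other classes.
+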
+ def get_random_demo_images():
+     """Get random images from current dataset"""
+     try:
+         if current_data_source == "custom" and custom_images:
+             images = custom_images
+             descriptions = custom_descriptions
+         else:
+             images = demo_data['image']
+             descriptions = demo_data['text']
+
+         if len(images) == 0:
+             return []
+
+         # Get random indices without repeats
+         indices = np.random.choice(len(images), min(6, len(images)), replace=False)
+
+         results = []
+         for idx in indices:
+             results.append((images[idx], f"Image {idx}: {descriptions[idx][:100]}..."))
+
+         return results
+
+     except Exception:
+         return []
+
+ def switch_data_source(choice):
+     """Switch between demo and custom data source"""
+     global current_data_source
+     current_data_source = "demo" if choice == "Demo Dataset" else "custom"
+
+     if current_data_source == "custom" and not custom_images:
+         return "⚠️ Custom folder not loaded. Please load a custom folder first."
+     elif current_data_source == "custom":
+         return f"✅ Switched to custom folder ({len(custom_images)} images)"
+     else:
+         # Guard against a failed initial load, where demo_data is still None
+         demo_count = len(demo_data) if demo_data is not None else 0
+         return f"✅ Switched to demo dataset ({demo_count} images)"
+
+ # Initialize the model when the module loads
+ initialization_status = load_model_and_demo_data()
+
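+ # Note: load_model_and_demo_data() runs at import time, so the first launch also
+ # downloads the CLIP weights and the demo dataset; its status string is surfaced
+ # in the "System Status" box of the UI below.
+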
+ # Create Gradio interface
+ with gr.Blocks(title="AI Image Discovery Studio", theme=gr.themes.Soft()) as demo:
+     gr.Markdown("""
+     # 🖼️ AI Image Discovery Studio
+
+     Search images using natural language or find visually similar content with CLIP embeddings!
+     """)
+
+     # Status display
+     with gr.Row():
+         status_display = gr.Textbox(
+             value=initialization_status,
+             label="System Status",
+             interactive=False
+         )
+
+     # Data source selection and custom folder loading
+     with gr.Row():
+         with gr.Column(scale=1):
+             data_source_radio = gr.Radio(
+                 ["Demo Dataset", "Custom Folder"],
+                 value="Demo Dataset",
+                 label="Data Source"
+             )
+
+             folder_path_input = gr.Textbox(
+                 label="Custom Folder Path",
+                 placeholder="e.g., /path/to/your/images",
+                 visible=False
+             )
+
+             load_folder_btn = gr.Button("Load Custom Folder", visible=False)
+             folder_status = gr.Textbox(label="Folder Status", visible=False, interactive=False)
+
+         with gr.Column(scale=2):
+             source_status = gr.Textbox(
+                 # Guard against a failed initial load, where demo_data is still None
+                 value=(
+                     f"✅ Using demo dataset ({len(demo_data)} images)"
+                     if demo_data is not None
+                     else "⚠️ Demo dataset not loaded"
+                 ),
+                 label="Current Data Source",
+                 interactive=False
+             )
+
+     # Show/hide custom folder controls based on selection
+     def toggle_folder_controls(choice):
+         visible = choice == "Custom Folder"
+         return (
+             gr.update(visible=visible),  # folder_path_input
+             gr.update(visible=visible),  # load_folder_btn
+             gr.update(visible=visible)   # folder_status
+         )
+
+     data_source_radio.change(
+         toggle_folder_controls,
+         inputs=[data_source_radio],
+         outputs=[folder_path_input, load_folder_btn, folder_status]
+     )
+
+     # Update data source status
+     data_source_radio.change(
+         switch_data_source,
+         inputs=[data_source_radio],
+         outputs=[source_status]
+     )
+
+     # Load custom folder
+     load_folder_btn.click(
+         load_custom_folder,
+         inputs=[folder_path_input],
+         outputs=[folder_status]
+     )
+
+     # Main tabs
+     with gr.Tabs():
+         # Text to Image Search Tab
+         with gr.TabItem("🔤 Text to Image Search"):
+             gr.Markdown("Enter a text description to find matching images")
+
+             with gr.Row():
+                 with gr.Column():
+                     text_query = gr.Textbox(
+                         label="Search Query",
+                         placeholder="e.g., 'Dog running on grass', 'Beautiful sunset over mountains'"
+                     )
+                     text_top_k = gr.Slider(1, 10, value=5, step=1, label="Number of Results")
+                     text_search_btn = gr.Button("🔍 Search Images", variant="primary")
+
+                 with gr.Column():
+                     text_search_status = gr.Textbox(label="Search Status", interactive=False)
+
+             text_results = gr.Gallery(
+                 label="Search Results",
+                 show_label=True,
+                 elem_id="text_search_gallery",
+                 columns=5,
+                 rows=1,
+                 height="auto"
+             )
+
+             # Connect text search
+             text_search_btn.click(
+                 lambda query, top_k, source: search_images_by_text(
+                     query, top_k, "custom" if source == "Custom Folder" else "demo"
+                 ),
+                 inputs=[text_query, text_top_k, data_source_radio],
+                 outputs=[text_results, text_search_status]
+             )
+
+         # Image to Image Search Tab
+         with gr.TabItem("🖼️ Image to Image Search"):
+             gr.Markdown("Upload an image to find visually similar ones")
+
+             with gr.Row():
+                 with gr.Column():
+                     query_image = gr.Image(label="Query Image", type="pil")
+                     image_top_k = gr.Slider(1, 10, value=5, step=1, label="Number of Results")
+                     image_search_btn = gr.Button("🔍 Find Similar Images", variant="primary")
+
+                 with gr.Column():
+                     image_search_status = gr.Textbox(label="Search Status", interactive=False)
+
+             image_results = gr.Gallery(
+                 label="Similar Images",
+                 show_label=True,
+                 elem_id="image_search_gallery",
+                 columns=5,
+                 rows=1,
+                 height="auto"
+             )
+
+             # Connect image search
+             image_search_btn.click(
+                 lambda img, top_k, source: search_similar_images(
+                     img, top_k, "custom" if source == "Custom Folder" else "demo"
+                 ),
+                 inputs=[query_image, image_top_k, data_source_radio],
+                 outputs=[image_results, image_search_status]
+             )
+
+         # Zero-Shot Classification Tab
+         with gr.TabItem("🏷️ Zero-Shot Classification"):
+             gr.Markdown("Classify an image with custom labels using CLIP")
+
+             with gr.Row():
+                 with gr.Column():
+                     classify_image_input = gr.Image(label="Image to Classify", type="pil")
+                     labels_input = gr.Textbox(
+                         label="Classification Labels (one per line)",
+                         value="cat\ndog\ncar\nbird\nflower",
+                         lines=5
+                     )
+                     classify_btn = gr.Button("🔍 Classify Image", variant="primary")
+
+                 with gr.Column():
+                     classification_results = gr.Textbox(
+                         label="Detailed Results",
+                         lines=10,
+                         interactive=False
+                     )
+
+             classification_plot = gr.Plot(label="Classification Results")
+
+             # Connect classification
+             classify_btn.click(
+                 classify_image,
+                 inputs=[classify_image_input, labels_input],
+                 outputs=[classification_plot, classification_results]
+             )
+
+         # Dataset Explorer Tab
+         with gr.TabItem("📊 Dataset Explorer"):
+             gr.Markdown("Browse through the dataset images")
+
+             with gr.Row():
+                 random_sample_btn = gr.Button("🎲 Show Random Sample", variant="primary")
+
+             explorer_gallery = gr.Gallery(
+                 label="Dataset Sample",
+                 show_label=True,
+                 elem_id="explorer_gallery",
+                 columns=3,
+                 rows=2,
+                 height="auto"
+             )
+
+             # Connect random sampling
+             random_sample_btn.click(
+                 get_random_demo_images,
+                 outputs=[explorer_gallery]
+             )
+
+ # Launch the app
+ if __name__ == "__main__":
+     demo.launch()
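To try this file locally, something like the following should work (the package list is inferred from the imports above; this commit pins no versions):

    pip install gradio torch transformers datasets pillow matplotlib numpy
    python gradio_app.py

Gradio serves the interface at http://localhost:7860 by default. Note that initialization happens at import time, so the first run blocks until the CLIP weights and the demo dataset have finished downloading.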