from typing import List, Dict, Optional, Tuple
import faiss
import angle_emb
import torch
import numpy as np
from tqdm import tqdm
from datasets import Dataset
class FlickrAngleSearch:
    def __init__(self, model_name: str = "WhereIsAI/UAE-Large-V1", device: str = "cuda:0"):
        """Initialize the search engine with model and empty index"""
        self.model = angle_emb.AnglE(model_name, pooling_strategy='cls', device=device)
        self._index: Optional[faiss.IndexFlatIP] = None
        self.captions: Optional[List[str]] = None
        self.caption2image: Optional[Dict[str, int]] = None
        self.ds: Optional[Dataset] = None

    def index(self, dataset: Dataset) -> "FlickrAngleSearch":
        """Build the search index from a dataset"""
        self.ds = dataset

        # Extract unique captions and build caption->image mapping
        captions: List[str] = []
        caption2image: Dict[str, int] = {}
        for i, example in enumerate(tqdm(dataset)):
            for caption in example['caption']:
                if caption not in caption2image:
                    captions.append(caption)
                    caption2image[caption] = i
        self.captions = captions
        self.caption2image = caption2image

        # Encode captions
        print(f"Encoding {len(captions)} unique captions...")
        caption_embeddings = self.encode(captions)

        # Build FAISS index
        dimension = caption_embeddings.shape[1]
        self._index = faiss.IndexFlatIP(dimension)
        self._index.add(caption_embeddings)
        return self

    @classmethod
    def from_preindexed(cls, index_path: str, captions_path: str, caption2image_path: str,
                        dataset: Dataset, device: str = "cpu") -> "FlickrAngleSearch":
        """Load a pre-built index and mappings"""
        instance = cls(device=device)
        instance._index = faiss.read_index(index_path)
        instance.captions = torch.load(captions_path)
        instance.caption2image = torch.load(caption2image_path)
        instance.ds = dataset
        return instance

    def save_index(self, index_path: str, captions_path: str, caption2image_path: str) -> None:
        """Save the index and mappings to disk"""
        faiss.write_index(self._index, index_path)
        torch.save(self.captions, captions_path)
        torch.save(self.caption2image, caption2image_path)

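    # A minimal build-and-save sketch (illustrative assumption, not part of the running
    # app: it presumes a CUDA device and the nlphuji/flickr30k split used in __main__):
    #
    #   ds = load_dataset("nlphuji/flickr30k", split="test")
    #   search = FlickrAngleSearch(device="cuda:0").index(ds)
    #   search.save_index("index.faiss", "captions.pt", "caption2image.pt")
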
    def encode(self, texts: List[str], batch_size: int = 32) -> np.ndarray:
        """Encode a list of texts to embeddings"""
        embeddings: List[np.ndarray] = []
        for i in tqdm(range(0, len(texts), batch_size), desc="Encoding texts"):
            batch = texts[i:i + batch_size]
            with torch.no_grad():
                embs = self.model.encode(batch, to_numpy=True, device=self.model.device)
            embeddings.extend(embs)
        return np.stack(embeddings)

    def search(self, query: str, k: int = 5) -> List[Tuple[float, str, int]]:
        """
        Search for the top-k most relevant captions and their corresponding images

        Args:
            query: Text query to search for
            k: Number of results to return

        Returns:
            List of (score, caption, image_index) tuples
        """
        # Encode the query text
        query_embedding = self.encode([query])

        # Search the index
        scores, indices = self._index.search(query_embedding, k)

        # Get the results
        results: List[Tuple[float, str, int]] = []
        for score, idx in zip(scores[0], indices[0]):
            caption = self.captions[idx]
            image_idx = self.caption2image[caption]
            results.append((float(score), caption, image_idx))
        return results

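# Quick query sketch (illustrative only; assumes a built or pre-loaded instance like the
# one created in __main__ below; scores are raw inner products from IndexFlatIP):
#
#   for score, caption, image_idx in search.search("two dogs playing in the snow", k=3):
#       print(f"{score:.4f}\t{caption}\t(image #{image_idx})")
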
if __name__ == "__main__":
    import os

    import gradio as gr
    from datasets import load_dataset
    from huggingface_hub import snapshot_download

    # Download the pre-built FAISS index and caption mappings, then load the dataset
    local_dir = snapshot_download(repo_id='WhereIsAI/angle-flickr-index')
    ds = load_dataset("nlphuji/flickr30k", split='test')

    search = FlickrAngleSearch.from_preindexed(
        os.path.join(local_dir, 'index.faiss'),
        os.path.join(local_dir, 'captions.pt'),
        os.path.join(local_dir, 'caption2image.pt'),
        ds,
        device='cpu'
    )

    def search_and_display(query, num_results=5):
        # Cast k defensively: slider values may arrive as floats
        results = search.search(query, k=int(num_results))
        images = []
        captions = []
        similarities = []
        # Deduplicate by image so each gallery entry shows a distinct picture
        visited_images = set()
        for similarity, caption, image_idx in results:
            if image_idx not in visited_images:
                visited_images.add(image_idx)
                images.append(ds[image_idx]['image'])
                captions.append(caption)
                similarities.append(f"{similarity:.4f}")
        return images, captions, similarities

    demo = gr.Interface(
        fn=search_and_display,
        inputs=[
            gr.Textbox(label="Search Query"),
            gr.Slider(minimum=1, maximum=10, value=5, step=1, label="Number of Results")
        ],
        outputs=[
            gr.Gallery(label="Top Results"),
            gr.Dataframe(headers=["Caption"], label="Captions"),
            gr.Dataframe(headers=["Similarity Score"], label="Similarity Scores")
        ],
        title="Flickr Image Search",
        description="Search through Flickr images using natural language queries"
    )
    demo.launch(share=True)