from typing import List, Dict, Optional, Tuple

import faiss
import angle_emb
import torch
import numpy as np
from tqdm import tqdm
from datasets import Dataset


class FlickrAngleSearch:
    def __init__(self, model_name: str = "WhereIsAI/UAE-Large-V1", device: str = "cuda:0"):
        """Initialize the search engine with model and empty index"""
        self.model = angle_emb.AnglE(model_name, pooling_strategy='cls', device=device)
        self._index: Optional[faiss.IndexFlatIP] = None
        self.captions: Optional[List[str]] = None
        self.caption2image: Optional[Dict[str, int]] = None
        self.ds: Optional[Dataset] = None

    def index(self, dataset: Dataset) -> "FlickrAngleSearch":
        """Build the search index from a dataset"""
        self.ds = dataset

        # Extract unique captions and build caption->image mapping
        captions: List[str] = []
        caption2image: Dict[str, int] = {}
        for i, example in enumerate(tqdm(dataset)):
            for caption in example['caption']:
                if caption not in caption2image:
                    captions.append(caption)
                    caption2image[caption] = i

        self.captions = captions
        self.caption2image = caption2image

        # Encode captions
        print(f"Encoding {len(captions)} unique captions...")
        caption_embeddings = self.encode(captions)

        # Build FAISS index
        dimension = caption_embeddings.shape[1]
        self._index = faiss.IndexFlatIP(dimension)
        self._index.add(caption_embeddings)

        return self

    @classmethod
    def from_preindexed(cls,
                        index_path: str,
                        captions_path: str,
                        caption2image_path: str,
                        dataset: Dataset,
                        device: str = "cpu") -> "FlickrAngleSearch":
        """Load a pre-built index and mappings"""
        instance = cls(device=device)
        instance._index = faiss.read_index(index_path)
        instance.captions = torch.load(captions_path)
        instance.caption2image = torch.load(caption2image_path)
        instance.ds = dataset
        return instance

    def save_index(self, index_path: str, captions_path: str, caption2image_path: str) -> None:
        """Save the index and mappings to disk"""
        faiss.write_index(self._index, index_path)
        torch.save(self.captions, captions_path)
        torch.save(self.caption2image, caption2image_path)

    def encode(self, texts: List[str], batch_size: int = 32) -> np.ndarray:
        """Encode a list of texts to embeddings"""
        embeddings: List[np.ndarray] = []
        for i in tqdm(range(0, len(texts), batch_size), desc="Encoding texts"):
            batch = texts[i:i + batch_size]
            with torch.no_grad():
                embs = self.model.encode(batch, to_numpy=True, device=self.model.device)
            embeddings.extend(embs)
        return np.stack(embeddings)

    def search(self, query: str, k: int = 5) -> List[Tuple[float, str, int]]:
        """
        Search for the top-k most relevant captions and their corresponding images

        Args:
            query: Text query to search for
            k: Number of results to return

        Returns:
            List of (score, caption, image_index) tuples
        """
        # Encode the query text
        query_embedding = self.encode([query])

        # Search the index
        scores, indices = self._index.search(query_embedding, k)

        # Get the results
        results: List[Tuple[float, str, int]] = []
        for score, idx in zip(scores[0], indices[0]):
            caption = self.captions[idx]
            image_idx = self.caption2image[caption]
            results.append((float(score), caption, image_idx))

        return results


if __name__ == "__main__":
    import os
    import gradio as gr
    from datasets import load_dataset
    from huggingface_hub import snapshot_download

    local_dir = snapshot_download(repo_id='WhereIsAI/angle-flickr-index-v2')
    ds = load_dataset("WhereIsAI/flickr30k-v2", split='train')
    search = FlickrAngleSearch.from_preindexed(
        os.path.join(local_dir, 'index.faiss'),
        os.path.join(local_dir, 'captions.pt'),
        os.path.join(local_dir, 'caption2image.pt'),
        ds,
        device='cpu'
    )

    def search_and_display(query, num_results=5):
        results = search.search(query, k=num_results)
        images = []
        captions = []
        similarities = []
        visited_images = set()
        # Keep only the first (highest-scoring) caption per image
        for similarity, caption, image_idx in results:
            if image_idx not in visited_images:
                visited_images.add(image_idx)
                images.append(ds[image_idx]['image'])
                captions.append(caption)
                similarities.append(f"{similarity:.4f}")
        return images, captions, similarities

    demo = gr.Interface(
        fn=search_and_display,
        inputs=[
            gr.Textbox(label="Search Query"),
            gr.Slider(minimum=1, maximum=10, value=5, step=1, label="Number of Results")
        ],
        outputs=[
            gr.Gallery(label="Top Results"),
            gr.Dataframe(headers=["Caption"], label="Captions"),
            gr.Dataframe(headers=["Similarity Score"], label="Similarity Scores")
        ],
        title="Flickr Image Search",
        description="Search through Flickr images using natural language queries"
    )
    demo.launch(share=True)
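# Note: the demo above loads a pre-built index from the Hub via `from_preindexed`.
# A minimal sketch of building and persisting a fresh index instead (assumes a GPU
# is available; the output file names below are illustrative):
#
#     builder = FlickrAngleSearch(device='cuda:0')
#     builder.index(ds)
#     builder.save_index('index.faiss', 'captions.pt', 'caption2image.pt')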