from typing import List, Dict, Optional, Tuple

import faiss
import angle_emb
import torch
import numpy as np
from tqdm import tqdm
from datasets import Dataset


class FlickrAngleSearch:
    def __init__(self, model_name: str = "WhereIsAI/UAE-Large-V1", device: str = "cuda:0"):
        """Initialize the search engine with the embedding model and an empty index."""
        self.model = angle_emb.AnglE(model_name, pooling_strategy='cls', device=device)
        self._index: Optional[faiss.IndexFlatIP] = None
        self.captions: Optional[List[str]] = None
        self.caption2image: Optional[Dict[str, int]] = None
        self.ds: Optional[Dataset] = None

    def index(self, dataset: Dataset) -> "FlickrAngleSearch":
        """Build the search index from a dataset."""
        self.ds = dataset

        # Extract unique captions and build a caption -> image-index mapping
        captions: List[str] = []
        caption2image: Dict[str, int] = {}
        for i, example in enumerate(tqdm(dataset)):
            for caption in example['caption']:
                if caption not in caption2image:
                    captions.append(caption)
                    caption2image[caption] = i
        self.captions = captions
        self.caption2image = caption2image

        # Encode captions
        print(f"Encoding {len(captions)} unique captions...")
        caption_embeddings = self.encode(captions)

        # Build FAISS index
        dimension = caption_embeddings.shape[1]
        self._index = faiss.IndexFlatIP(dimension)
        self._index.add(caption_embeddings)
        return self

    @classmethod
    def from_preindexed(cls, index_path: str, captions_path: str, caption2image_path: str,
                        dataset: Dataset, device: str = "cpu") -> "FlickrAngleSearch":
        """Load a pre-built index and caption mappings from disk."""
        instance = cls(device=device)
        instance._index = faiss.read_index(index_path)
        instance.captions = torch.load(captions_path)
        instance.caption2image = torch.load(caption2image_path)
        instance.ds = dataset
        return instance

    def save_index(self, index_path: str, captions_path: str, caption2image_path: str) -> None:
        """Save the index and caption mappings to disk."""
        faiss.write_index(self._index, index_path)
        torch.save(self.captions, captions_path)
        torch.save(self.caption2image, caption2image_path)
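
    # Building from scratch (sketch): index a caption dataset once, then persist it with
    # save_index(). Assumes a dataset whose examples carry a 'caption' list, as in
    # WhereIsAI/flickr30k-v2; the variable names and device below are illustrative.
    #
    #   ds = load_dataset("WhereIsAI/flickr30k-v2", split="train")
    #   engine = FlickrAngleSearch(device="cuda:0").index(ds)
    #   engine.save_index("index.faiss", "captions.pt", "caption2image.pt")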

    def encode(self, texts: List[str], batch_size: int = 32) -> np.ndarray:
        """Encode a list of texts into a (len(texts), dim) embedding matrix."""
        embeddings: List[np.ndarray] = []
        for i in tqdm(range(0, len(texts), batch_size), desc="Encoding texts"):
            batch = texts[i:i + batch_size]
            with torch.no_grad():
                embs = self.model.encode(batch, to_numpy=True, device=self.model.device)
            embeddings.extend(embs)
        return np.stack(embeddings)

    def search(self, query: str, k: int = 5) -> List[Tuple[float, str, int]]:
        """
        Search for the top-k most relevant captions and their corresponding images.

        Args:
            query: Text query to search for.
            k: Number of results to return.

        Returns:
            List of (score, caption, image_index) tuples.
        """
        # Encode the query text
        query_embedding = self.encode([query])

        # Search the index
        scores, indices = self._index.search(query_embedding, k)

        # Collect the results
        results: List[Tuple[float, str, int]] = []
        for score, idx in zip(scores[0], indices[0]):
            caption = self.captions[idx]
            image_idx = self.caption2image[caption]
            results.append((float(score), caption, image_idx))
        return results
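

# Programmatic use (sketch): an indexed or pre-loaded instance (here called `engine`)
# can be queried directly, without the Gradio UI; the query string is illustrative.
#
#   results = engine.search("a dog catching a frisbee", k=3)
#   for score, caption, image_idx in results:
#       print(f"{score:.4f}  {caption}  (image #{image_idx})")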


if __name__ == "__main__":
    import os

    import gradio as gr
    from datasets import load_dataset
    from huggingface_hub import snapshot_download

    local_dir = snapshot_download(repo_id='WhereIsAI/angle-flickr-index-v2')
    ds = load_dataset("WhereIsAI/flickr30k-v2", split='train')

    search = FlickrAngleSearch.from_preindexed(
        os.path.join(local_dir, 'index.faiss'),
        os.path.join(local_dir, 'captions.pt'),
        os.path.join(local_dir, 'caption2image.pt'),
        ds,
        device='cpu'
    )

    def search_and_display(query, num_results=5):
        # The slider value may arrive as a float; FAISS expects an integer k
        results = search.search(query, k=int(num_results))
        images = []
        captions = []
        similarities = []
        visited_images = set()
        # Show each image at most once, keeping its first (highest-scoring) caption
        for similarity, caption, image_idx in results:
            if image_idx not in visited_images:
                visited_images.add(image_idx)
                images.append(ds[image_idx]['image'])
                captions.append(caption)
                similarities.append(f"{similarity:.4f}")
        return images, captions, similarities

    demo = gr.Interface(
        fn=search_and_display,
        inputs=[
            gr.Textbox(label="Search Query"),
            gr.Slider(minimum=1, maximum=10, value=5, step=1, label="Number of Results")
        ],
        outputs=[
            gr.Gallery(label="Top Results"),
            gr.Dataframe(headers=["Caption"], label="Captions"),
            gr.Dataframe(headers=["Similarity Score"], label="Similarity Scores")
        ],
        title="Flickr Image Search",
        description="Search through Flickr images using natural language queries"
    )

    demo.launch(share=True)