from typing import List, Dict, Optional, Tuple
import faiss
import angle_emb
import torch
import numpy as np
from tqdm import tqdm
from datasets import Dataset
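

# FlickrAngleSearch pairs an AnglE sentence encoder with a FAISS inner-product index
# over image captions; each matched caption maps back to the dataset row of its image.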
class FlickrAngleSearch:
    def __init__(self, model_name: str = "WhereIsAI/UAE-Large-V1", device: str = "cuda:0"):
        """Initialize the search engine with the embedding model and an empty index."""
        self.model = angle_emb.AnglE(model_name, pooling_strategy='cls', device=device)
        self._index: Optional[faiss.IndexFlatIP] = None
        self.captions: Optional[List[str]] = None
        self.caption2image: Optional[Dict[str, int]] = None
        self.ds: Optional[Dataset] = None

    def index(self, dataset: Dataset) -> "FlickrAngleSearch":
        """Build the search index from a dataset"""
        self.ds = dataset

        # Extract unique captions and build caption->image mapping
        captions: List[str] = []
        caption2image: Dict[str, int] = {}
        for i, example in enumerate(tqdm(dataset)):
            for caption in example['caption']:
                if caption not in caption2image:
                    captions.append(caption)
                    caption2image[caption] = i
        self.captions = captions
        self.caption2image = caption2image

        # Encode captions
        print(f"Encoding {len(captions)} unique captions...")
        caption_embeddings = self.encode(captions)

        # Build FAISS index
        dimension = caption_embeddings.shape[1]
        self._index = faiss.IndexFlatIP(dimension)
        self._index.add(caption_embeddings)
        return self
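
    # Offline indexing sketch (hypothetical snippet, not part of the app's runtime path):
    #   engine = FlickrAngleSearch(device='cuda:0')
    #   engine.index(load_dataset("WhereIsAI/flickr30k-v2", split='train'))
    #   engine.save_index('index.faiss', 'captions.pt', 'caption2image.pt')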

    @classmethod
    def from_preindexed(cls, index_path: str, captions_path: str, caption2image_path: str,
                        dataset: Dataset, device: str = "cpu") -> "FlickrAngleSearch":
        """Load a pre-built index and mappings"""
        instance = cls(device=device)
        instance._index = faiss.read_index(index_path)
        instance.captions = torch.load(captions_path)
        instance.caption2image = torch.load(caption2image_path)
        instance.ds = dataset
        return instance

    def save_index(self, index_path: str, captions_path: str, caption2image_path: str) -> None:
        """Save the index and mappings to disk"""
        faiss.write_index(self._index, index_path)
        torch.save(self.captions, captions_path)
        torch.save(self.caption2image, caption2image_path)

    def encode(self, texts: List[str], batch_size: int = 32) -> np.ndarray:
        """Encode a list of texts to embeddings"""
        embeddings: List[np.ndarray] = []
        for i in tqdm(range(0, len(texts), batch_size), desc="Encoding texts"):
            batch = texts[i:i + batch_size]
            with torch.no_grad():
                embs = self.model.encode(batch, to_numpy=True, device=self.model.device)
            embeddings.extend(embs)
        return np.stack(embeddings)

    def search(self, query: str, k: int = 5) -> List[Tuple[float, str, int]]:
        """Search for the top-k most relevant captions and their corresponding images.

        Args:
            query: Text query to search for
            k: Number of results to return

        Returns:
            List of (score, caption, image_index) tuples
        """
        # Encode the query text
        query_embedding = self.encode([query])
        # Search the index
        scores, indices = self._index.search(query_embedding, k)
        # Get the results
        results: List[Tuple[float, str, int]] = []
        for score, idx in zip(scores[0], indices[0]):
            caption = self.captions[idx]
            image_idx = self.caption2image[caption]
            results.append((float(score), caption, image_idx))
        return results
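

# Programmatic usage sketch (hypothetical; assumes an indexed or pre-loaded instance `engine`):
#   results = engine.search("two dogs playing in the snow", k=5)
#   for score, caption, image_idx in results:
#       print(f"{score:.4f}  {caption}  (image #{image_idx})")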

if __name__ == "__main__":
    import os

    import gradio as gr
    from datasets import load_dataset
    from huggingface_hub import snapshot_download

    local_dir = snapshot_download(repo_id='WhereIsAI/angle-flickr-index-v2')
    ds = load_dataset("WhereIsAI/flickr30k-v2", split='train')
    search = FlickrAngleSearch.from_preindexed(
        os.path.join(local_dir, 'index.faiss'),
        os.path.join(local_dir, 'captions.pt'),
        os.path.join(local_dir, 'caption2image.pt'),
        ds,
        device='cpu'
    )
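
    # Gradio callback: returns gallery images, their matching captions, and similarity
    # scores, skipping duplicate images when several of their captions match the query.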
    def search_and_display(query, num_results=5):
        results = search.search(query, k=num_results)
        images = []
        captions = []
        similarities = []
        visited_images = set()
        for similarity, caption, image_idx in results:
            if image_idx not in visited_images:
                visited_images.add(image_idx)
                images.append(ds[image_idx]['image'])
                captions.append(caption)
                similarities.append(f"{similarity:.4f}")
        return images, captions, similarities

    demo = gr.Interface(
        fn=search_and_display,
        inputs=[
            gr.Textbox(label="Search Query"),
            gr.Slider(minimum=1, maximum=10, value=5, step=1, label="Number of Results")
        ],
        outputs=[
            gr.Gallery(label="Top Results"),
            gr.Dataframe(headers=["Caption"], label="Captions"),
            gr.Dataframe(headers=["Similarity Score"], label="Similarity Scores")
        ],
        title="Flickr Image Search",
        description="Search through Flickr images using natural language queries"
    )
    demo.launch(share=True)