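"""Semantic caption search over Flickr30k.

Captions are embedded with an AnglE model, indexed with FAISS for
inner-product search, and served through a small Gradio demo.
"""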
from typing import List, Dict, Optional, Tuple

import faiss
import angle_emb
import torch
import numpy as np
from tqdm import tqdm
from datasets import Dataset


class FlickrAngleSearch:
    def __init__(self, model_name: str = "WhereIsAI/UAE-Large-V1", device: str = "cuda:0"):
        """Initialize the search engine with model and empty index"""
        self.model = angle_emb.AnglE(model_name, pooling_strategy='cls', device=device)
        self._index: Optional[faiss.IndexFlatIP] = None
        self.captions: Optional[List[str]] = None
        self.caption2image: Optional[Dict[str, int]] = None
        self.ds: Optional[Dataset] = None

    def index(self, dataset: Dataset) -> "FlickrAngleSearch":
        """Build the search index from a dataset"""
        self.ds = dataset

        # Extract unique captions and build caption->image mapping
        captions: List[str] = []
        caption2image: Dict[str, int] = {}
        for i, example in enumerate(tqdm(dataset)):
            for caption in example['caption']:
                if caption not in caption2image:
                    captions.append(caption)
                    caption2image[caption] = i

        self.captions = captions
        self.caption2image = caption2image

        # Encode captions
        print(f"Encoding {len(captions)} unique captions...")
        caption_embeddings = self.encode(captions)

        # Build a FAISS inner-product index; scores equal cosine similarity
        # only when the embeddings are normalized
        dimension = caption_embeddings.shape[1]
        self._index = faiss.IndexFlatIP(dimension)
        # FAISS expects a contiguous float32 matrix
        self._index.add(np.ascontiguousarray(caption_embeddings, dtype=np.float32))

        return self

    @classmethod
    def from_preindexed(cls, index_path: str, captions_path: str, caption2image_path: str, dataset: Dataset, device: str = "cpu") -> "FlickrAngleSearch":
        """Load a pre-built index and mappings"""
        instance = cls(device=device)
        instance._index = faiss.read_index(index_path)
        instance.captions = torch.load(captions_path)
        instance.caption2image = torch.load(caption2image_path)
        instance.ds = dataset
        return instance

    def save_index(self, index_path: str, captions_path: str, caption2image_path: str) -> None:
        """Save the index and mappings to disk"""
        faiss.write_index(self._index, index_path)
        torch.save(self.captions, captions_path)
        torch.save(self.caption2image, caption2image_path)

    def encode(self, texts: List[str], batch_size: int = 32) -> np.ndarray:
        """Encode a list of texts to embeddings"""
        embeddings: List[np.ndarray] = []
        for i in tqdm(range(0, len(texts), batch_size), desc="Encoding texts"):
            batch = texts[i:i + batch_size]
            with torch.no_grad():
                embs = self.model.encode(batch, to_numpy=True, device=self.model.device)
                embeddings.extend(embs)

        return np.stack(embeddings)

    def search(self, query: str, k: int = 5) -> List[Tuple[float, str, int]]:
        """
        Search for the top-k most relevant captions and their corresponding images

        Args:
            query: Text query to search for
            k: Number of results to return

        Returns:
            List of (score, caption, image_index) tuples
        """
        # Encode the query text (FAISS expects a contiguous float32 matrix)
        query_embedding = np.ascontiguousarray(self.encode([query]), dtype=np.float32)

        # Search the index
        scores, indices = self._index.search(query_embedding, k)

        # Get the results
        results: List[Tuple[float, str, int]] = []
        for score, idx in zip(scores[0], indices[0]):
            caption = self.captions[idx]
            image_idx = self.caption2image[caption]
            results.append((float(score), caption, image_idx))

        return results
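
# Sketch (illustrative, not part of the demo below): building and saving a
# fresh index when a pre-built one is unavailable; the output paths here are
# assumed, not files shipped with this module.
#
#     from datasets import load_dataset
#     ds = load_dataset("WhereIsAI/flickr30k-v2", split="train")
#     search = FlickrAngleSearch(device="cuda:0").index(ds)
#     search.save_index("index.faiss", "captions.pt", "caption2image.pt")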


if __name__ == "__main__":
    import os
    import gradio as gr
    from datasets import load_dataset
    from huggingface_hub import snapshot_download

    local_dir = snapshot_download(repo_id='WhereIsAI/angle-flickr-index-2')

    ds = load_dataset("WhereIsAI/flickr30k-v2", split='train')
    search = FlickrAngleSearch.from_preindexed(
        os.path.join(local_dir, 'index.faiss'),
        os.path.join(local_dir, 'captions.pt'),
        os.path.join(local_dir, 'caption2image.pt'),
        ds,
        device='cpu'
    )

    def search_and_display(query, num_results=5):
        # The slider value may arrive as a float; FAISS requires an integer k
        results = search.search(query, k=int(num_results))
        images = []
        captions = []
        similarities = []

        # Several captions can map to the same image; display each image once
        visited_images = set()
        for similarity, caption, image_idx in results:
            if image_idx not in visited_images:
                visited_images.add(image_idx)
                images.append(ds[image_idx]['image'])
            captions.append([caption])
            similarities.append([f"{similarity:.4f}"])

        # gr.Dataframe expects row-wise 2D data, hence the single-column rows
        return images, captions, similarities

    demo = gr.Interface(
        fn=search_and_display,
        inputs=[
            gr.Textbox(label="Search Query"),
            gr.Slider(minimum=1, maximum=10, value=5, step=1, label="Number of Results")
        ],
        outputs=[
            gr.Gallery(label="Top Results"),
            gr.Dataframe(headers=["Caption"], label="Captions"),
            gr.Dataframe(headers=["Similarity Score"], label="Similarity Scores")
        ],
        title="Flickr Image Search",
        description="Search through Flickr images using natural language queries"
    )

    demo.launch(share=True)  # share=True also exposes a temporary public gradio.live link
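
# Rough dependency sketch (package names assumed from the imports above):
#     pip install angle-emb faiss-cpu torch numpy tqdm datasets gradio huggingface_hub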