SeanLee97 committed
Commit 5a30220 · 1 Parent(s): 490ff19

init commit

Files changed (2)
  1. app.py +147 -0
  2. requirements.txt +6 -0
app.py ADDED
@@ -0,0 +1,147 @@
+ from typing import List, Dict, Optional, Tuple
+
+ import faiss
+ import angle_emb
+ import torch
+ import numpy as np
+ from tqdm import tqdm
+ from datasets import Dataset
+
+
+ class FlickrAngleSearch:
+     def __init__(self, model_name: str = "WhereIsAI/UAE-Large-V1", device: str = "cuda:0"):
+         """Initialize the search engine with model and empty index"""
+         self.model = angle_emb.AnglE(model_name, pooling_strategy='cls', device=device)
+         self._index: Optional[faiss.IndexFlatIP] = None
+         self.captions: Optional[List[str]] = None
+         self.caption2image: Optional[Dict[str, int]] = None
+         self.ds: Optional[Dataset] = None
+
+     def index(self, dataset: Dataset) -> "FlickrAngleSearch":
+         """Build the search index from a dataset"""
+         self.ds = dataset
+
+         # Extract unique captions and build caption -> image mapping
+         captions: List[str] = []
+         caption2image: Dict[str, int] = {}
+         for i, example in enumerate(tqdm(dataset)):
+             for caption in example['caption']:
+                 if caption not in caption2image:
+                     captions.append(caption)
+                     caption2image[caption] = i
+
+         self.captions = captions
+         self.caption2image = caption2image
+
+         # Encode captions
+         print(f"Encoding {len(captions)} unique captions...")
+         caption_embeddings = self.encode(captions)
+
+         # Build FAISS index
+         dimension = caption_embeddings.shape[1]
+         self._index = faiss.IndexFlatIP(dimension)
+         self._index.add(caption_embeddings)
+
+         return self
+
+     @classmethod
+     def from_preindexed(cls, index_path: str, captions_path: str, caption2image_path: str, dataset: Dataset, device: str = "cpu") -> "FlickrAngleSearch":
+         """Load a pre-built index and mappings"""
+         instance = cls(device=device)
+         instance._index = faiss.read_index(index_path)
+         instance.captions = torch.load(captions_path)
+         instance.caption2image = torch.load(caption2image_path)
+         instance.ds = dataset
+         return instance
+
+     def save_index(self, index_path: str, captions_path: str, caption2image_path: str) -> None:
+         """Save the index and mappings to disk"""
+         faiss.write_index(self._index, index_path)
+         torch.save(self.captions, captions_path)
+         torch.save(self.caption2image, caption2image_path)
+
+     def encode(self, texts: List[str], batch_size: int = 32) -> np.ndarray:
+         """Encode a list of texts to embeddings"""
+         embeddings: List[np.ndarray] = []
+         for i in tqdm(range(0, len(texts), batch_size), desc="Encoding texts"):
+             batch = texts[i:i + batch_size]
+             with torch.no_grad():
+                 embs = self.model.encode(batch, to_numpy=True, device=self.model.device)
+             embeddings.extend(embs)
+
+         return np.stack(embeddings)
+
+     def search(self, query: str, k: int = 5) -> List[Tuple[float, str, int]]:
+         """
+         Search for the top-k most relevant captions and their corresponding images
+
+         Args:
+             query: Text query to search for
+             k: Number of results to return
+
+         Returns:
+             List of (score, caption, image_index) tuples
+         """
+         # Encode the query text
+         query_embedding = self.encode([query])
+
+         # Search the index
+         scores, indices = self._index.search(query_embedding, k)
+
+         # Collect the results
+         results: List[Tuple[float, str, int]] = []
+         for score, idx in zip(scores[0], indices[0]):
+             caption = self.captions[idx]
+             image_idx = self.caption2image[caption]
+             results.append((float(score), caption, image_idx))
+
+         return results
+
+
+ if __name__ == "__main__":
+     import os
+     import gradio as gr
+     from datasets import load_dataset
+     from huggingface_hub import snapshot_download
+
+     local_dir = snapshot_download(repo_id='WhereIsAI/angle-flickr-index')
+
+     ds = load_dataset("nlphuji/flickr30k", split='test')
+     search = FlickrAngleSearch.from_preindexed(
+         os.path.join(local_dir, 'index.faiss'),
+         os.path.join(local_dir, 'captions.pt'),
+         os.path.join(local_dir, 'caption2image.pt'),
+         ds,
+         device='cpu'
+     )
+
+     def search_and_display(query, num_results=5):
+         results = search.search(query, k=int(num_results))  # cast: the slider value may arrive as a float, FAISS expects an integer k
+         images = []
+         captions = []
+         similarities = []
+
+         for similarity, caption, image_idx in results:
+             image = ds[image_idx]['image']
+             images.append(image)
+             captions.append(caption)
+             similarities.append(f"{similarity:.4f}")
+
+         return images, captions, similarities
+
+     demo = gr.Interface(
+         fn=search_and_display,
+         inputs=[
+             gr.Textbox(label="Search Query"),
+             gr.Slider(minimum=1, maximum=10, value=5, step=1, label="Number of Results")
+         ],
+         outputs=[
+             gr.Gallery(label="Top Results"),
+             gr.Dataframe(headers=["Caption"], label="Captions"),
+             gr.Dataframe(headers=["Similarity Score"], label="Similarity Scores")
+         ],
+         title="Flickr Image Search",
+         description="Search through Flickr images using natural language queries"
+     )
+
+     demo.launch(share=True)
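
Note: the main block above only loads pre-built artifacts downloaded from the WhereIsAI/angle-flickr-index repo. For reference, here is a minimal sketch of the one-time indexing step that would produce those three files using the FlickrAngleSearch class from app.py; running it offline on a GPU machine and the local output file names are assumptions, though the names match what app.py later loads:

    from datasets import load_dataset

    # Build the caption index once; the resulting files can then be uploaded to the
    # index repo (WhereIsAI/angle-flickr-index) that app.py downloads from.
    ds = load_dataset("nlphuji/flickr30k", split="test")
    engine = FlickrAngleSearch(device="cuda:0")  # assumed GPU device; "cpu" also works, just slower
    engine.index(ds)                             # encodes every unique caption and fills the FAISS index
    engine.save_index("index.faiss", "captions.pt", "caption2image.pt")
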
requirements.txt ADDED
@@ -0,0 +1,6 @@
+ torch
+ datasets
+ faiss-cpu
+ gradio
+ huggingface-hub
+ angle-emb
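
With the dependencies above installed (for example via pip install -r requirements.txt), running python app.py downloads the pre-built index files and the Flickr30k test split and then starts the demo; since app.py calls demo.launch(share=True), Gradio will also attempt to create a temporary public share link in addition to the local URL.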