File size: 7,383 Bytes
9ec4f1c
 
 
 
 
 
 
910eeac
 
9ec4f1c
 
910eeac
9ec4f1c
 
 
 
 
910eeac
624f840
 
9ec4f1c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ba6de20
9ec4f1c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e20472a
 
 
 
 
 
 
 
 
 
 
 
 
 
9ec4f1c
 
 
 
 
 
 
 
 
 
 
 
 
e20472a
 
9ec4f1c
 
 
 
 
e20472a
9ec4f1c
 
 
ba6de20
 
910eeac
e20472a
 
 
9ec4f1c
 
e20472a
 
 
 
9ec4f1c
 
 
07167dd
910eeac
9ec4f1c
 
 
 
 
 
 
910eeac
9ec4f1c
 
 
ba6de20
9ec4f1c
ba6de20
9ec4f1c
 
 
ba6de20
 
 
9ec4f1c
 
 
 
ba6de20
 
 
 
910eeac
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
# ----------------------------------------------------------------------------
# Copyright (c) 2024 Amar Ali-bey
#
# nanoCLIP: https://github.com/amaralibey/nanoCLIP
#
# Licensed under the MIT License. See LICENSE file in the project root.
# ----------------------------------------------------------------------------


from pathlib import Path
from typing import List, Tuple, Optional

import torch
import torch.nn.functional as F
import faiss
from transformers import AutoTokenizer
import gradio as gr

from text_encoder import TextEncoder
from load_album import AlbumDataset

class ImageSearchEngine:
    """Text-to-image retrieval over a pre-built FAISS index.

    A text query is tokenized, embedded by ``TextEncoder``, L2-normalized and
    looked up in the FAISS index; the returned positions are mapped back to
    entries of ``self.dataset.imgs``.
    """

    def __init__(
        self,
        model_name: str = "sentence-transformers/all-MiniLM-L6-v2",
        output_dim: int = 64,
        gallery_folder: str = "photos",
        device: str = 'cpu'
    ):
        """Load the text encoder, the tokenizer, the gallery and its index.

        Args:
            model_name: Name passed to ``TextEncoder`` (as ``lang_model``) and
                to ``AutoTokenizer.from_pretrained``.
            output_dim: Embedding size passed to ``TextEncoder``.
            gallery_folder: Folder name under ``gallery/`` next to this file.
            device: Torch device string; ``'cuda'`` falls back to ``'cpu'``
                when CUDA is unavailable.

        Raises:
            FileNotFoundError: If the encoder weights, the gallery folder or
                the FAISS index file is missing.
        """
        if device == 'cuda' and not torch.cuda.is_available():
            print("CUDA is not available. Using CPU instead.")
            device = 'cpu'
        self.device = torch.device(device)
        self.setup_model(model_name, output_dim)
        self.setup_gallery(gallery_folder)

    def setup_model(self, model_name: str, output_dim: int) -> None:
        """Initialize the text encoder, load its weights and the tokenizer.

        Raises:
            FileNotFoundError: If ``txt_encoder_state_dict.pth`` is not found
                next to this file.
        """
        self.txt_encoder = TextEncoder(
            output_dim=output_dim,
            lang_model=model_name
        ).to(self.device)

        # Load the pre-trained weights for the text encoder; they are expected
        # in the same folder as this file.
        weights_path = Path(__file__).parent.resolve() / 'txt_encoder_state_dict.pth'
        if not weights_path.exists():
            raise FileNotFoundError(f"Text encoder weights not found: {weights_path}, make sure to run the create_index.py script.")
        # weights_only=True restricts unpickling to tensors/plain containers.
        weights = torch.load(weights_path, map_location=self.device, weights_only=True)
        self.txt_encoder.load_state_dict(weights)
        self.txt_encoder.eval()

        self.tokenizer = AutoTokenizer.from_pretrained(model_name)

    def setup_gallery(self, gallery_folder: str) -> None:
        """Set up the image gallery and read its FAISS index.

        Raises:
            FileNotFoundError: If ``gallery/<gallery_folder>`` or the sibling
                ``<gallery_folder>.faiss`` file does not exist.
        """
        gallery_path = Path(__file__).parent.resolve() / f'gallery/{gallery_folder}'
        if not gallery_path.exists():
            raise FileNotFoundError(f"Album folder {gallery_path} not found")
        # We use the AlbumDataset class to load the image paths (we won't load
        # the images themselves). This is more efficient than loading the
        # images directly, because Gradio will load them given the paths
        # returned by the search method.
        self.dataset = AlbumDataset(gallery_path, transform=None)

        # The index file lives next to the gallery folder and carries the same
        # name as the folder being indexed.
        index_path = gallery_path.parent / f"{gallery_folder}.faiss"
        # Fail with the same explicit error type as the other missing-file
        # checks instead of whatever faiss raises on an unreadable path.
        if not index_path.exists():
            raise FileNotFoundError(f"FAISS index {index_path} not found")
        self.index = faiss.read_index(index_path.as_posix())

    @torch.no_grad()
    def encode_query(self, query_text: str) -> torch.Tensor:
        """Encode the text query into an L2-normalized embedding on the CPU.

        Only the ``input_ids`` produced by the tokenizer are fed to the
        encoder. The result is normalized along dim 1 (presumably shaped
        ``(1, output_dim)`` for a single query; confirm against TextEncoder).
        """
        inputs = self.tokenizer(query_text, truncation=True, return_tensors="pt")
        inputs = inputs['input_ids'].to(self.device)

        embedding = self.txt_encoder(inputs)
        embedding = F.normalize(embedding, p=2, dim=1)
        return embedding.cpu()

    def search(self, query_text: str, k: int = 20) -> List[Tuple[str, Optional[str]]]:
        """Search for images matching the query text.

        Args:
            query_text: Free-text query. Empty, whitespace-only and very short
                queries (fewer than 3 characters once stripped) yield no
                results.
            k: Maximum number of results to return.

        Returns:
            A list of ``(image, None)`` tuples, where ``image`` is the matching
            entry of ``self.dataset.imgs`` and ``None`` stands for "no caption".
            The list may hold fewer than ``k`` items.
        """
        # avoid searching for missing or very short queries
        if not query_text or len(query_text.strip()) < 3:
            return []
        if k <= 0:
            return []

        query_embedding = self.encode_query(query_text)
        # Hand faiss a numpy array rather than a torch tensor: its Python API
        # is numpy-based.
        _, indices = self.index.search(query_embedding.numpy(), k)
        # FAISS pads the result with -1 when it finds fewer than k neighbours;
        # without this filter, -1 would silently select the LAST gallery image.
        # (Results could additionally be filtered by a distance threshold.)
        return [(self.dataset.imgs[int(idx)], None) for idx in indices[0] if idx >= 0]

class GalleryUI:
    """Gradio front-end: a search box wired to an image gallery."""

    def __init__(self, search_engine: ImageSearchEngine):
        """Keep the search backend and locate ``style.css`` next to this file."""
        self.search_engine = search_engine
        self.css_path = Path(__file__).parent / 'style.css'

    def load_css(self) -> str:
        """Return the contents of the stylesheet at ``self.css_path``.

        The file is decoded as UTF-8 explicitly so the result does not depend
        on the platform's default encoding.

        Raises:
            OSError: If the stylesheet cannot be read (e.g. it is missing).
        """
        return self.css_path.read_text(encoding="utf-8")

    def create_interface(self) -> gr.Blocks:
        """Build and return the Gradio ``Blocks`` app."""
        custom_theme = gr.themes.Soft(
            text_size='lg',
            primary_hue="purple",
            secondary_hue="gray",
            font=[gr.themes.GoogleFont("Inter"), "system-ui", "sans-serif"],
            font_mono=["IBM Plex Mono", "monospace"]
        ).set(
            button_primary_background_fill="*primary_300",
            button_primary_background_fill_hover="*primary_200",
            block_shadow="*shadow_drop_lg",
            block_border_width="2px"
        )
        with gr.Blocks(css=self.load_css(), theme=custom_theme) as demo:
            with gr.Column(elem_classes="container"):
                self._create_header()
                self._create_search_section()
                self._create_footer()

            # Registered while the Blocks context is still open, after the
            # components it references have been created.
            self._setup_callbacks(demo)
            return demo

    def _create_header(self) -> None:
        """Create the header section (title and two description lines)."""
        with gr.Column(elem_classes="header"):
            gr.Markdown("# Gallery Search")
            gr.Markdown("Search through your collection of photos with AI")
            gr.Markdown("`in this demo, you are searching COCO dataset images`")

    def _create_search_section(self) -> None:
        """Create the query textbox and the results gallery.

        Both components are stored on ``self`` (``query_text``, ``gallery``)
        so that ``_setup_callbacks`` can reference them.
        """
        with gr.Column():
            self.query_text = gr.Textbox(
                placeholder="Example: Riding my horse [Enter]",
                label="Search Query",
                elem_classes="search-input",
                autofocus=True,
                container=False,
                interactive=True
            )

        with gr.Column(elem_classes="gallery"):
            self.gallery = gr.Gallery(
                label="Search Results",
                columns=6,
                object_fit="cover",
                elem_classes="gallery",
                container=False,
                show_share_button=False,
            )

    def _create_footer(self) -> None:
        """Create the footer section with author and repository links."""
        with gr.Column(elem_classes="footer"):
            gr.Markdown(
                """Created by [Amar Ali-bey](https://amaralibey.github.io) | 
                [View on GitHub](https://github.com/amaralibey/nanoCLIP)"""
            )

    def _setup_callbacks(self, demo: gr.Blocks) -> None:
        """Run a search when the query textbox is submitted.

        ``demo`` is accepted for the caller's convenience but is not used here.
        """
        self.query_text.submit(
            self.search_engine.search,
            inputs=[self.query_text],
            outputs=self.gallery,
            show_progress='hidden',
        )


# The backend and the UI are built at import time, so `search_engine`, `ui`
# and `demo` exist as module-level names whether or not this file is executed
# as a script.
_engine_settings = {
    "model_name": "sentence-transformers/all-MiniLM-L6-v2",
    "output_dim": 64,
    "gallery_folder": "photos",
}
search_engine = ImageSearchEngine(**_engine_settings)
ui = GalleryUI(search_engine)
demo = ui.create_interface()

if __name__ == "__main__":
    # Start the Gradio server only when run directly.
    demo.launch()