Spaces:
Sleeping
Sleeping
File size: 7,383 Bytes
9ec4f1c 910eeac 9ec4f1c 910eeac 9ec4f1c 910eeac 624f840 9ec4f1c ba6de20 9ec4f1c e20472a 9ec4f1c e20472a 9ec4f1c e20472a 9ec4f1c ba6de20 910eeac e20472a 9ec4f1c e20472a 9ec4f1c 07167dd 910eeac 9ec4f1c 910eeac 9ec4f1c ba6de20 9ec4f1c ba6de20 9ec4f1c ba6de20 9ec4f1c ba6de20 910eeac |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 |
# ----------------------------------------------------------------------------
# Copyright (c) 2024 Amar Ali-bey
#
# nanoCLIP: https://github.com/amaralibey/nanoCLIP
#
# Licensed under the MIT License. See LICENSE file in the project root.
# ----------------------------------------------------------------------------
from pathlib import Path
from typing import List, Tuple, Optional
import torch
import torch.nn.functional as F
import faiss
from transformers import AutoTokenizer
import gradio as gr
from text_encoder import TextEncoder
from load_album import AlbumDataset
class ImageSearchEngine:
    """Text-to-image search over a gallery that has already been indexed.

    A text query is tokenized, embedded with ``TextEncoder``, L2-normalized
    and looked up in a FAISS index read from disk. The returned labels are
    used as positions into ``self.dataset.imgs``.

    Attributes set during construction:
        device: ``torch.device`` the text encoder runs on.
        txt_encoder: ``TextEncoder`` in eval mode with the saved weights loaded.
        tokenizer: Hugging Face tokenizer matching ``model_name``.
        dataset: ``AlbumDataset`` built on the gallery folder.
        index: FAISS index loaded from ``gallery/<gallery_folder>.faiss``.
    """

    def __init__(
        self,
        model_name: str = "sentence-transformers/all-MiniLM-L6-v2",
        output_dim: int = 64,
        gallery_folder: str = "photos",
        device: str = 'cpu'
    ):
        """Load the text encoder and the gallery index.

        Args:
            model_name: Language model identifier, passed both to
                ``TextEncoder`` and to ``AutoTokenizer.from_pretrained``.
            output_dim: Output dimension passed to ``TextEncoder``.
            gallery_folder: Name of the folder under ``gallery/`` (next to
                this file) that holds the images.
            device: ``'cpu'`` or ``'cuda'``; falls back to CPU when CUDA is
                requested but not available.

        Raises:
            FileNotFoundError: If the encoder weights, the gallery folder or
                the FAISS index file is missing.
        """
        if device == 'cuda' and not torch.cuda.is_available():
            print("CUDA is not available. Using CPU instead.")
            device = 'cpu'
        self.device = torch.device(device)
        self.setup_model(model_name, output_dim)
        self.setup_gallery(gallery_folder)

    def setup_model(self, model_name: str, output_dim: int) -> None:
        """Initialize the text encoder, load its weights and the tokenizer.

        Raises:
            FileNotFoundError: If ``txt_encoder_state_dict.pth`` is not found
                next to this file.
        """
        self.txt_encoder = TextEncoder(
            output_dim=output_dim,
            lang_model=model_name
        ).to(self.device)

        # The pre-trained weights are expected next to this source file.
        weights_path = Path(__file__).parent.resolve() / 'txt_encoder_state_dict.pth'
        if not weights_path.exists():
            raise FileNotFoundError(f"Text encoder weights not found: {weights_path}, make sure to run the create_index.py script.")

        # weights_only=True restricts unpickling to tensors/primitive types.
        weights = torch.load(weights_path, map_location=self.device, weights_only=True)
        self.txt_encoder.load_state_dict(weights)
        self.txt_encoder.eval()

        self.tokenizer = AutoTokenizer.from_pretrained(model_name)

    def setup_gallery(self, gallery_folder: str) -> None:
        """Set up the image gallery and load its FAISS index.

        Raises:
            FileNotFoundError: If the gallery folder or its ``.faiss`` index
                file does not exist.
        """
        gallery_path = Path(__file__).parent.resolve() / f'gallery/{gallery_folder}'
        if not gallery_path.exists():
            raise FileNotFoundError(f"Album folder {gallery_path} not found")

        # we use the AlbumDataset class to load the image paths (we won't load the images themselves)
        # this is more efficient than loading the images directly, because Gradio will load them
        # given the paths returned by the search method.
        self.dataset = AlbumDataset(gallery_path, transform=None)

        # The index file lives in the same folder as the gallery folder and
        # has the same name as the folder being indexed.
        index_path = gallery_path.parent / f"{gallery_folder}.faiss"
        # Fail with the same exception type as the other missing-file checks
        # instead of the low-level error raised by faiss.read_index.
        if not index_path.exists():
            raise FileNotFoundError(f"FAISS index {index_path} not found")
        self.index = faiss.read_index(index_path.as_posix())

    @torch.no_grad()
    def encode_query(self, query_text: str) -> torch.Tensor:
        """Encode the text query into an L2-normalized embedding on the CPU.

        Args:
            query_text: The raw query string.

        Returns:
            The encoder output normalized along ``dim=1`` and moved to the CPU
            (presumably of shape ``(1, output_dim)`` -- confirm against
            ``TextEncoder``).
        """
        inputs = self.tokenizer(query_text, truncation=True, return_tensors="pt")
        # Only the token ids are forwarded to the encoder; a single query is
        # tokenized without any padding argument.
        inputs = inputs['input_ids'].to(self.device)
        embedding = self.txt_encoder(inputs)
        embedding = F.normalize(embedding, p=2, dim=1)
        return embedding.cpu()

    def search(self, query_text: str, k: int = 20) -> List[Tuple[str, Optional[str]]]:
        """Search for images matching the query text.

        Args:
            query_text: Free-text query. Queries with fewer than 3
                non-whitespace-trimmed characters (or ``None``) yield no results.
            k: Maximum number of results requested from the index.

        Returns:
            A list of ``(image, None)`` tuples, where ``image`` is the entry of
            ``self.dataset.imgs`` at each returned label; the ``None`` is the
            (absent) caption slot of a gallery item. The list may be shorter
            than ``k`` when the index holds fewer than ``k`` vectors.
        """
        # Avoid searching for very short queries; surrounding whitespace does
        # not count towards the minimum length.
        if query_text is None or len(query_text.strip()) < 3:
            return []

        query_embedding = self.encode_query(query_text)

        # The FAISS Python API is defined over float32 NumPy arrays, so
        # convert explicitly instead of relying on implicit tensor coercion.
        queries = query_embedding.float().contiguous().numpy()
        _distances, indices = self.index.search(queries, k)

        # FAISS pads missing neighbours with the label -1. Without filtering,
        # Python's negative indexing would silently return the *last* image.
        # you can additionally filter results according to a threshold on the distance
        return [(self.dataset.imgs[idx], None) for idx in indices[0] if idx >= 0]
class GalleryUI:
    """Gradio front-end for an ``ImageSearchEngine``.

    The interface consists of a header, a search textbox with a results
    gallery, and a footer. Submitting the textbox calls
    ``search_engine.search`` and shows the returned items in the gallery.
    """

    def __init__(self, search_engine: ImageSearchEngine):
        """Store the search backend and locate ``style.css`` next to this file."""
        self.search_engine = search_engine
        self.css_path = Path(__file__).parent / 'style.css'

    def load_css(self) -> str:
        """Load CSS styles from file.

        Returns:
            The full text of the stylesheet at ``self.css_path``.
        """
        # Decode explicitly as UTF-8 so the result does not depend on the
        # platform's locale encoding.
        with open(self.css_path, encoding="utf-8") as f:
            return f.read()

    def create_interface(self) -> gr.Blocks:
        """Create the Gradio interface.

        Returns:
            The assembled ``gr.Blocks`` app, with the submit callback wired.
        """
        custom_theme = gr.themes.Soft(
            text_size='lg',
            primary_hue="purple",
            secondary_hue="gray",
            font=[gr.themes.GoogleFont("Inter"), "system-ui", "sans-serif"],
            font_mono=["IBM Plex Mono", "monospace"]
        ).set(
            button_primary_background_fill="*primary_300",
            button_primary_background_fill_hover="*primary_200",
            block_shadow="*shadow_drop_lg",
            block_border_width="2px"
        )

        with gr.Blocks(css=self.load_css(), theme=custom_theme) as demo:
            with gr.Column(elem_classes="container"):
                self._create_header()
                self._create_search_section()
                self._create_footer()
            # Components must exist before events can be attached to them.
            self._setup_callbacks(demo)
        return demo

    def _create_header(self) -> None:
        """Create the header section."""
        with gr.Column(elem_classes="header"):
            gr.Markdown("# Gallery Search")
            gr.Markdown("Search through your collection of photos with AI")
            gr.Markdown("`in this demo, you are searching COCO dataset images`")

    def _create_search_section(self) -> None:
        """Create the search textbox and the results gallery.

        Stores the components as ``self.query_text`` and ``self.gallery`` so
        that ``_setup_callbacks`` can wire them together.
        """
        with gr.Column():
            self.query_text = gr.Textbox(
                placeholder="Example: Riding my horse [Enter]",
                label="Search Query",
                elem_classes="search-input",
                autofocus=True,
                container=False,
                interactive=True
            )

        with gr.Column(elem_classes="gallery"):
            self.gallery = gr.Gallery(
                label="Search Results",
                columns=6,
                object_fit="cover",
                elem_classes="gallery",
                container=False,
                show_share_button=False,
            )

    def _create_footer(self) -> None:
        """Create the footer section."""
        with gr.Column(elem_classes="footer"):
            gr.Markdown(
                """Created by [Amar Ali-bey](https://amaralibey.github.io) |
                [View on GitHub](https://github.com/amaralibey/nanoCLIP)"""
            )

    def _setup_callbacks(self, demo: gr.Blocks) -> None:
        """Setup the interface callbacks.

        Args:
            demo: The enclosing Blocks app (currently not used by this method).
        """
        # Submitting the textbox runs the search and fills the gallery.
        self.query_text.submit(
            self.search_engine.search,
            inputs=[self.query_text],
            outputs=self.gallery,
            show_progress='hidden',
        )
# Build the search backend and the UI at import time so that `demo` is
# available as a module-level name, not only when the file is run as a script.
search_engine = ImageSearchEngine(
    model_name="sentence-transformers/all-MiniLM-L6-v2",
    output_dim=64,
    gallery_folder="photos",
)
ui = GalleryUI(search_engine)
demo = ui.create_interface()

# Start the Gradio server only when executed directly.
if __name__ == "__main__":
    demo.launch()
|