File size: 3,140 Bytes
40a2cd4
 
 
 
 
 
 
 
 
 
 
 
 
 
80175e6
40a2cd4
 
 
 
 
 
 
 
 
 
 
 
 
 
80175e6
40a2cd4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
import glob
import os

import streamlit as st
from datastore import ChromaStore
from embeddings import Embedding
from PIL import Image
from tqdm import tqdm

from utils import base64_to_image, image_to_base64

##### Image database
# 64x64 thumbnails of every ``*.jpg`` found (recursively) under ./data; these
# are rendered in the "Image Database" expander. The path list is sorted because
# glob's ordering is filesystem-dependent, which made the gallery order unstable.
root_dir = os.path.join(os.getcwd(), "data")
jpg_files = sorted(glob.glob(os.path.join(root_dir, "**", "*.jpg"), recursive=True))


def _load_thumbnail(path, size=(64, 64)):
    """Return a resized copy of the image at *path*.

    The source image is opened in a ``with`` block so its file handle is
    released even for multi-frame files, which Pillow does not auto-close after
    loading. The returned resized image is independent of that handle.
    """
    with Image.open(path) as img:
        return img.resize(size)


IMAGE_DATABASE = [_load_thumbnail(f) for f in jpg_files]


def display_image_database():
    """Show all preloaded database thumbnails inside a collapsible "Image Database" section."""
    with st.expander(label="Image Database"):
        st.image(IMAGE_DATABASE)


def display_sample_images():
    """Render 64x64 thumbnails of the example query images stored in ./sample_imgs.

    Entries are shown in sorted filename order so the layout is stable across
    reruns. Directory entries that cannot be opened as images (sub-directories,
    ``.DS_Store``, stray text files, ...) are skipped instead of crashing the
    whole page.
    """
    sample_img_path = os.path.join(os.getcwd(), "sample_imgs")

    images = []
    for name in sorted(os.listdir(sample_img_path)):
        try:
            # ``with`` releases the file handle once the resized copy exists.
            with Image.open(os.path.join(sample_img_path, name)) as img:
                images.append(img.resize((64, 64)))
        except OSError:
            # Pillow's UnidentifiedImageError subclasses OSError, as do the
            # errors raised when the entry is a directory; just skip it.
            continue

    st.image(images)


def _search_similar_images(upload_img, top_k):
    """Embed the uploaded query image and return the images of its ``top_k`` search results.

    Encodes the query with ``Embedding.encode_image``, then queries the
    ``image_store`` Chroma collection for ``top_k`` results.
    """
    query_embedding = Embedding.encode_image(Image.open(upload_img))

    vectorstore = ChromaStore(collection_name="image_store")
    collection = vectorstore.create()
    st.toast("Vectorstore loaded successfully", icon="✅")

    results = vectorstore.query(
        collection,
        query_embedding,
        top_k=top_k,
    )
    # Element 0 of each result is what gets rendered. Presumably it is the
    # stored image; TODO confirm against ChromaStore.query.
    return [res[0] for res in results]


def main():
    """Build the page: a search panel, a results panel and the database gallery.

    The left column takes a query image and a Submit button. The right column
    shows the ``top_k`` results once Submit is pressed. Below both, the full
    image database is shown in an expander.
    """
    st.set_page_config(page_icon="🖼️", page_title="image-search-engine", layout="wide")
    st.markdown(
        """<h1 style="text-align: center;">🔍️ Image Search Engine</h1>""",
        unsafe_allow_html=True,
    )
    st.markdown(
        """<h3 style="text-align: center;">Image to Image search using transformer embeddings</h3>""",
        unsafe_allow_html=True,
    )

    main_layout = st.columns(2)

    with main_layout[0]:
        with st.container(border=True, height=550):
            st.markdown(
                """<h3 style="text-align: center;">Search</h3>""",
                unsafe_allow_html=True,
            )
            upload_img = st.file_uploader(
                label="Query Image",
                accept_multiple_files=False,
                type=["jpg", "png", "jpeg"],
            )

            submit = st.button(label="Submit")
            display_sample_images()

    with main_layout[1]:
        with st.container(border=True, height=550):
            st.markdown(
                """<h3 style="text-align: center;">Results</h3>""",
                unsafe_allow_html=True,
            )
            top_k = st.slider(label="Search top k results", min_value=3, max_value=10)
            if upload_img is None:
                st.warning("Please upload an image")
            elif submit:
                st.image(_search_similar_images(upload_img, top_k))
            else:
                # An image is present but the search has not been triggered yet.
                # Previously this state wrongly asked the user to upload an image.
                st.info("Press Submit to search for similar images")

    display_image_database()


# Build the page only when this file is executed as the main script, not when imported.
if __name__ == "__main__":
    main()