File size: 5,797 Bytes
5e4b1fa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50fa949
5e4b1fa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
import chromadb
from chromadb.config import Settings
import torchvision.models as models
import torch
from torchvision import transforms
from PIL import Image
import logging
import streamlit as st
import requests
import json
import uuid
import os

try:

    # Module-level logging setup. The entire app body runs inside this try so
    # that any startup or UI failure is funneled into the except handler at
    # the bottom of the file instead of crashing the Streamlit process.
    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger(__name__)
    @st.cache_resource
    def load_mobilenet_model():
        """Build and return the distilled MobileNetV3-Small embedding model.

        The stock classifier head is swapped for a 1024->768 linear layer so
        the network emits 768-d embeddings, then the distilled weights are
        loaded from 'mobilenet_v3_small_distilled_new_state_dict.pth'.
        Cached by Streamlit so the model is constructed once per session.
        """
        target_device = 'cpu'
        net = models.mobilenet_v3_small(pretrained=False)
        # Replace the final classifier layer with a 768-d embedding head.
        net.classifier[3] = torch.nn.Linear(1024, 768)
        state = torch.load(
            'mobilenet_v3_small_distilled_new_state_dict.pth',
            map_location=target_device)
        net.load_state_dict(state)
        # eval() and to() both return the module itself.
        return net.eval().to(target_device)

    @st.cache_resource
    def load_chromadb():
        """Open the persistent ChromaDB store and return its 'images' collection.

        The store lives under ./data and anonymized telemetry is disabled.
        Cached by Streamlit so the client is created once per session.
        """
        client = chromadb.PersistentClient(
            path='data',
            settings=Settings(anonymized_telemetry=False),
        )
        return client.get_collection(name='images')

    # Eagerly initialise the model and vector store at import time so a
    # missing weights file or database fails fast into the outer except.
    model = load_mobilenet_model()
    logger.info("MobileNet loaded")
    collection = load_chromadb()
    logger.info("ChromaDB loaded")
    logger.info(
        f"Connected to ChromaDB collection images with {collection.count()} items")

    # Standard ImageNet preprocessing: 224x224 resize plus ImageNet mean/std
    # normalisation — must match the pipeline the model was distilled with.
    preprocess = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[
                             0.229, 0.224, 0.225])
    ])

    def get_image_embedding(image):
        """Return an L2-normalised 768-d embedding for an image.

        Args:
            image: A filesystem path or a file-like object (e.g. a Streamlit
                upload). ``PIL.Image.open`` accepts both, so no branching on
                the input type is needed.

        Returns:
            list[float]: The embedding as a plain Python list, unit-length
            under the L2 norm (matching the vectors stored in ChromaDB).
        """
        # The original isinstance(image, str) branch had two identical arms;
        # Image.open handles both paths and file-like objects directly.
        img = Image.open(image).convert('RGB')
        input_tensor = preprocess(img).unsqueeze(0).to('cpu')
        with torch.no_grad():
            student_embedding = model(input_tensor)
        return torch.nn.functional.normalize(
            student_embedding, p=2, dim=1).squeeze(0).tolist()

    def save_image(image_file):
        """Persist an uploaded file under the local 'images' directory.

        Args:
            image_file: An uploaded file object exposing ``name`` and
                ``getbuffer()`` (Streamlit's UploadedFile API).

        Returns:
            str: The path the file was written to.

        Note:
            The file keeps its original name, so re-uploading a file with the
            same name silently overwrites the previous copy. TODO(review):
            prefix with uuid4 for true uniqueness — the hand-built metadata
            path in the add-images flow must change in lockstep.
        """
        # Create the target directory if needed; previously this raised
        # FileNotFoundError on a fresh checkout with no 'images' folder.
        os.makedirs('images', exist_ok=True)
        save_path = os.path.join('images', image_file.name)
        with open(save_path, "wb") as f:
            f.write(image_file.getbuffer())
        return save_path

    def resize_image(image_path, size=(224, 224)):
        """Load an image and resize it for display.

        Args:
            image_path: Filesystem path or file-like object; ``PIL.Image.open``
                accepts both, so no type branching is required (the original
                isinstance check had two identical arms).
            size: Target (width, height); defaults to (224, 224).

        Returns:
            PIL.Image.Image: The resized RGB image.
        """
        img = Image.open(image_path).convert("RGB")
        # LANCZOS gives high-quality downsampling for thumbnails.
        return img.resize(size, Image.LANCZOS)

    # --- Sidebar: upload controls and result-count slider ---------------
    st.sidebar.header("Upload Images")
    image_files = st.sidebar.file_uploader(
        "Upload images", type=["png", "jpg", "jpeg"], accept_multiple_files=True)
    num_images = st.sidebar.slider(
        "Number of results to return", min_value=1, max_value=10, value=3)

    if image_files:
        st.sidebar.subheader(
            "Add Images to collection")
        if st.sidebar.button("Add uploaded images"):
            # Embed each upload, persist it to disk, and index it in ChromaDB.
            # (enumerate() was dropped: the loop index was never used.)
            for image_file in image_files:
                image_embedding = get_image_embedding(image_file)
                save_image(image_file)  # side effect: writes images/<name>
                unique_id = str(uuid.uuid4())
                # NOTE(review): this hand-built path should always equal the
                # path save_image() returns; consider using that return value
                # directly to keep the two in sync.
                metadata = {
                    'path': f'images/{image_file.name}', "type": "photo"
                }
                collection.add(
                    embeddings=[image_embedding],
                    ids=[unique_id],
                    metadatas=[metadata]
                )
                st.sidebar.success(
                    f"Image {image_file.name} added to the collection")

    # --- Main panel: text-to-image search --------------------------------
    st.title('Image Search Using  Text')
    st.write(
        "The images stored in this database are sourced from the [COCO 2017 Validation Dataset](https://cocodataset.org/#download).")
    st.write('Enter the text to search for images with matching description')
    text_input = st.text_input("Description", "Road")
    if st.button("Search"):
        if text_input.strip():
            # Fetch the text embedding from the remote embedding API. A
            # timeout is set so a dead endpoint cannot hang the UI forever;
            # a timeout error propagates to the top-level except and is logged.
            response = requests.get(
                'https://ashish-001-text-embedding-api.hf.space/embedding',
                params={'text': text_input},
                timeout=30)
            if response.status_code == 200:
                logger.info("Embedding returned by API successfully")
                # response.json() replaces the manual json.loads(response.content).
                embedding = response.json()['embedding']
                results = collection.query(
                    query_embeddings=[embedding],
                    n_results=num_images
                )
                # Single query => ChromaDB nests each result list at index 0.
                images = [meta['path'] for meta in results['metadatas'][0]]
                distances = list(results['distances'][0])
                if images:
                    # Lay the results out in a grid, three thumbnails per row.
                    cols_per_row = 3
                    rows = (len(images) + cols_per_row - 1) // cols_per_row
                    for row in range(rows):
                        cols = st.columns(cols_per_row)
                        for col_idx, col in enumerate(cols):
                            img_idx = row * cols_per_row + col_idx
                            if img_idx < len(images):
                                resized_img = resize_image(
                                    images[img_idx], size=(224, 224))
                                col.image(
                                    resized_img,
                                    caption=f"Image {img_idx+1}\ndistance {distances[img_idx]}",
                                    use_container_width=True)
                else:
                    st.write("No image found")
            else:
                st.write("Please try again later")
                logger.info(f"status code {response.status_code} returned")
        else:
            st.write("Please enter text in the text area")

except Exception as e:
    logger.info(f"Exception occured: {e}")