NEXAS commited on
Commit
f0e3479
·
verified ·
1 Parent(s): 69c4c51

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +94 -141
src/streamlit_app.py CHANGED
@@ -5,20 +5,16 @@ from PIL import Image
5
  import os
6
  import numpy as np
7
  import chromadb
8
- from chromadb.utils import embedding_functions
9
  import tempfile
10
 
11
- # ----- Session Initialization -----
 
 
 
 
12
  if 'model' not in st.session_state:
13
  device = "cuda" if torch.cuda.is_available() else "cpu"
14
- cache_dir = tempfile.gettempdir()
15
-
16
- try:
17
- model, preprocess = clip.load("ViT-B/32", device=device, download_root=cache_dir)
18
- except Exception as e:
19
- st.error(f"Failed to load CLIP model: {e}")
20
- st.stop()
21
-
22
  st.session_state.model = model
23
  st.session_state.preprocess = preprocess
24
  st.session_state.device = device
@@ -26,21 +22,15 @@ if 'model' not in st.session_state:
26
  st.session_state.demo_image_paths = []
27
  st.session_state.user_images = []
28
 
29
- # ----- Initialize ChromaDB in Temp Dir -----
30
  if 'chroma_client' not in st.session_state:
31
- try:
32
- chroma_path = os.path.join(tempfile.gettempdir(), "chroma_db")
33
- st.session_state.chroma_client = chromadb.PersistentClient(path=chroma_path)
34
-
35
- st.session_state.demo_collection = st.session_state.chroma_client.get_or_create_collection(
36
- name="demo_images", metadata={"hnsw:space": "cosine"}
37
- )
38
- st.session_state.user_collection = st.session_state.chroma_client.get_or_create_collection(
39
- name="user_images", metadata={"hnsw:space": "cosine"}
40
- )
41
- except Exception as e:
42
- st.error(f"Failed to initialize ChromaDB: {e}")
43
- st.stop()
44
 
45
  # ----- Load Demo Images -----
46
  if not st.session_state.get("demo_images_loaded", False):
@@ -48,130 +38,93 @@ if not st.session_state.get("demo_images_loaded", False):
48
  if os.path.exists(demo_folder):
49
  demo_image_paths = [os.path.join(demo_folder, f) for f in os.listdir(demo_folder)
50
  if f.lower().endswith(('.png', '.jpg', '.jpeg'))]
 
 
51
 
52
- if demo_image_paths:
53
- st.session_state.demo_image_paths = demo_image_paths
54
- st.session_state.demo_images = [Image.open(path).convert("RGB") for path in demo_image_paths]
55
-
56
- # Clear previous collection
57
- try:
58
- st.session_state.demo_collection.delete(ids=[str(i) for i in range(len(demo_image_paths))])
59
- except:
60
- pass # Collection might be empty
61
-
62
- embeddings, ids, metadatas = [], [], []
63
- for i, img in enumerate(st.session_state.demo_images):
64
- img_pre = st.session_state.preprocess(img).unsqueeze(0).to(st.session_state.device)
65
- with torch.no_grad():
66
- embedding = st.session_state.model.encode_image(img_pre).cpu().numpy().flatten()
67
- embeddings.append(embedding)
68
- ids.append(str(i))
69
- metadatas.append({"path": demo_image_paths[i]})
70
-
71
- try:
72
- st.session_state.demo_collection.add(
73
- embeddings=embeddings,
74
- ids=ids,
75
- metadatas=metadatas
76
- )
77
- st.session_state.demo_images_loaded = True
78
- except Exception as e:
79
- st.error(f"Failed to add demo images to ChromaDB: {e}")
80
- else:
81
- st.warning("No images found in 'demo_images' folder.")
82
- else:
83
- st.warning("Folder 'demo_images' does not exist.")
84
 
85
- # ----- UI -----
86
- st.title("🔍 Image Search with CLIP")
87
- mode = st.radio("Select mode", ("Search in Demo Images", "Search in My Images"))
 
 
 
 
 
88
 
89
- # ----- Upload My Images -----
90
- if mode == "Search in My Images":
91
- st.subheader("Upload Your Images")
92
- uploaded_files = st.file_uploader("Choose images", type=['png', 'jpg', 'jpeg'], accept_multiple_files=True)
93
 
94
- if uploaded_files:
 
 
 
 
 
 
 
 
95
  st.session_state.user_images = []
 
 
 
96
 
97
- # Clear user collection
98
- try:
99
- st.session_state.user_collection.delete(ids=[
100
- str(i) for i in range(st.session_state.user_collection.count())
101
- ])
102
- except:
103
- pass
104
-
105
- for i, uploaded_file in enumerate(uploaded_files):
106
- img = Image.open(uploaded_file).convert("RGB")
107
  st.session_state.user_images.append(img)
108
 
109
- img_pre = st.session_state.preprocess(img).unsqueeze(0).to(st.session_state.device)
110
  with torch.no_grad():
111
- embedding = st.session_state.model.encode_image(img_pre).cpu().numpy().flatten()
112
-
113
- try:
114
- st.session_state.user_collection.add(
115
- embeddings=[embedding],
116
- ids=[str(i)],
117
- metadatas=[{"index": i}]
118
- )
119
- except Exception as e:
120
- st.error(f"Failed to add image {i}: {e}")
121
-
122
- if st.session_state.user_collection.count() > 0:
123
- st.success(f"Uploaded {len(st.session_state.user_images)} images.")
124
- else:
125
- st.warning("Upload failed.")
126
-
127
- # ----- Query Image -----
128
- st.subheader("Upload Query Image")
129
- query_file = st.file_uploader("Choose a query image", type=['png', 'jpg', 'jpeg'])
130
-
131
- if query_file is not None:
132
- query_img = Image.open(query_file).convert("RGB")
133
- st.image(query_img, caption="Query Image", width=200)
134
-
135
- query_pre = st.session_state.preprocess(query_img).unsqueeze(0).to(st.session_state.device)
136
- with torch.no_grad():
137
- query_embedding = st.session_state.model.encode_image(query_pre).cpu().numpy().flatten()
138
-
139
- # ----- Search in Demo -----
140
- if mode == "Search in Demo Images":
141
- if st.session_state.demo_collection.count() > 0:
142
- results = st.session_state.demo_collection.query(
143
- query_embeddings=[query_embedding],
144
- n_results=min(5, st.session_state.demo_collection.count())
145
- )
146
- distances = results['distances'][0]
147
- ids = results['ids'][0]
148
- similarities = [1 - dist for dist in distances]
149
-
150
- st.subheader("Top 5 Similar Demo Images")
151
- cols = st.columns(5)
152
- for i, (idx, sim) in enumerate(zip(ids, similarities)):
153
- img_idx = int(idx)
154
- with cols[i]:
155
- st.image(st.session_state.demo_images[img_idx], caption=f"Sim: {sim:.4f}", width=150)
156
- else:
157
- st.error("No demo images available.")
158
-
159
- # ----- Search in User Uploads -----
160
- elif mode == "Search in My Images":
161
- if st.session_state.user_collection.count() > 0:
162
- results = st.session_state.user_collection.query(
163
- query_embeddings=[query_embedding],
164
- n_results=min(5, st.session_state.user_collection.count())
165
  )
166
- distances = results['distances'][0]
167
- ids = results['ids'][0]
168
- similarities = [1 - dist for dist in distances]
169
-
170
- st.subheader("Top 5 Similar Uploaded Images")
171
- cols = st.columns(5)
172
- for i, (idx, sim) in enumerate(zip(ids, similarities)):
173
- img_idx = int(idx)
174
- with cols[i]:
175
- st.image(st.session_state.user_images[img_idx], caption=f"Sim: {sim:.4f}", width=150)
176
- else:
177
- st.error("Please upload some images first.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
  import os
6
  import numpy as np
7
  import chromadb
 
8
  import tempfile
9
 
10
+ # ----- Setup -----
11
+ CACHE_DIR = tempfile.gettempdir()
12
+ CHROMA_PATH = os.path.join(CACHE_DIR, "chroma_db")
13
+
14
+ # ----- Load CLIP Model -----
15
  if 'model' not in st.session_state:
16
  device = "cuda" if torch.cuda.is_available() else "cpu"
17
+ model, preprocess = clip.load("ViT-B/32", device=device, download_root=CACHE_DIR)
 
 
 
 
 
 
 
18
  st.session_state.model = model
19
  st.session_state.preprocess = preprocess
20
  st.session_state.device = device
 
22
  st.session_state.demo_image_paths = []
23
  st.session_state.user_images = []
24
 
25
+ # ----- Initialize ChromaDB -----
26
  if 'chroma_client' not in st.session_state:
27
+ st.session_state.chroma_client = chromadb.PersistentClient(path=CHROMA_PATH)
28
+ st.session_state.demo_collection = st.session_state.chroma_client.get_or_create_collection(
29
+ name="demo_images", metadata={"hnsw:space": "cosine"}
30
+ )
31
+ st.session_state.user_collection = st.session_state.chroma_client.get_or_create_collection(
32
+ name="user_images", metadata={"hnsw:space": "cosine"}
33
+ )
 
 
 
 
 
 
34
 
35
  # ----- Load Demo Images -----
36
  if not st.session_state.get("demo_images_loaded", False):
 
38
  if os.path.exists(demo_folder):
39
  demo_image_paths = [os.path.join(demo_folder, f) for f in os.listdir(demo_folder)
40
  if f.lower().endswith(('.png', '.jpg', '.jpeg'))]
41
+ st.session_state.demo_images = [Image.open(p).convert("RGB") for p in demo_image_paths]
42
+ st.session_state.demo_image_paths = demo_image_paths
43
 
44
+ st.session_state.demo_collection.delete(ids=[str(i) for i in range(len(demo_image_paths))])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45
 
46
+ embeddings, ids, metadatas = [], [], []
47
+ for i, img in enumerate(st.session_state.demo_images):
48
+ img_tensor = st.session_state.preprocess(img).unsqueeze(0).to(st.session_state.device)
49
+ with torch.no_grad():
50
+ embedding = st.session_state.model.encode_image(img_tensor).cpu().numpy().flatten()
51
+ embeddings.append(embedding)
52
+ ids.append(str(i))
53
+ metadatas.append({"path": demo_image_paths[i]})
54
 
55
+ st.session_state.demo_collection.add(embeddings=embeddings, ids=ids, metadatas=metadatas)
56
+ st.session_state.demo_images_loaded = True
 
 
57
 
58
+ # ----- UI -----
59
+ st.title("🔎 CLIP Image Search (Text & Image)")
60
+ mode = st.radio("Choose dataset to search in:", ("Demo Images", "My Uploaded Images"))
61
+ query_type = st.radio("Query type:", ("Image", "Text"))
62
+
63
+ # ----- Upload User Images -----
64
+ if mode == "My Uploaded Images":
65
+ uploaded = st.file_uploader("Upload your images", type=['jpg', 'jpeg', 'png'], accept_multiple_files=True)
66
+ if uploaded:
67
  st.session_state.user_images = []
68
+ st.session_state.user_collection.delete(ids=[
69
+ str(i) for i in range(st.session_state.user_collection.count())
70
+ ])
71
 
72
+ for i, file in enumerate(uploaded):
73
+ img = Image.open(file).convert("RGB")
 
 
 
 
 
 
 
 
74
  st.session_state.user_images.append(img)
75
 
76
+ img_tensor = st.session_state.preprocess(img).unsqueeze(0).to(st.session_state.device)
77
  with torch.no_grad():
78
+ embedding = st.session_state.model.encode_image(img_tensor).cpu().numpy().flatten()
79
+
80
+ st.session_state.user_collection.add(
81
+ embeddings=[embedding],
82
+ ids=[str(i)],
83
+ metadatas=[{"index": i}]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
84
  )
85
+
86
+ st.success(f"{len(uploaded)} images uploaded.")
87
+
88
+ # ----- Perform Query -----
89
+ query_embedding = None
90
+ if query_type == "Image":
91
+ img_file = st.file_uploader("Upload query image", type=["jpg", "jpeg", "png"])
92
+ if img_file:
93
+ img = Image.open(img_file).convert("RGB")
94
+ st.image(img, caption="Query Image", width=200)
95
+ img_tensor = st.session_state.preprocess(img).unsqueeze(0).to(st.session_state.device)
96
+ with torch.no_grad():
97
+ query_embedding = st.session_state.model.encode_image(img_tensor).cpu().numpy().flatten()
98
+ elif query_type == "Text":
99
+ text_query = st.text_input("Enter search text:")
100
+ if text_query:
101
+ tokens = clip.tokenize([text_query]).to(st.session_state.device)
102
+ with torch.no_grad():
103
+ query_embedding = st.session_state.model.encode_text(tokens).cpu().numpy().flatten()
104
+
105
+ # ----- Run Search -----
106
+ if query_embedding is not None:
107
+ if mode == "Demo Images":
108
+ collection = st.session_state.demo_collection
109
+ images = st.session_state.demo_images
110
+ else:
111
+ collection = st.session_state.user_collection
112
+ images = st.session_state.user_images
113
+
114
+ if collection.count() > 0:
115
+ results = collection.query(
116
+ query_embeddings=[query_embedding],
117
+ n_results=min(5, collection.count())
118
+ )
119
+ ids = results["ids"][0]
120
+ distances = results["distances"][0]
121
+ similarities = [1 - d for d in distances]
122
+
123
+ st.subheader("Top Matches")
124
+ cols = st.columns(5)
125
+ for i, (img_id, sim) in enumerate(zip(ids, similarities)):
126
+ with cols[i]:
127
+ idx = int(img_id)
128
+ st.image(images[idx], caption=f"Sim: {sim:.3f}", width=150)
129
+ else:
130
+ st.warning("No images found in collection.")