NEXAS commited on
Commit
69c4c51
Β·
verified Β·
1 Parent(s): bb61313

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +37 -16
src/streamlit_app.py CHANGED
@@ -6,17 +6,19 @@ import os
6
  import numpy as np
7
  import chromadb
8
  from chromadb.utils import embedding_functions
9
- import tempfile # βœ… For safe writable cache directory
10
 
11
- # Initialize session state
12
  if 'model' not in st.session_state:
13
  device = "cuda" if torch.cuda.is_available() else "cpu"
14
- cache_dir = tempfile.gettempdir() # βœ… Use temp dir instead of restricted local path
 
15
  try:
16
  model, preprocess = clip.load("ViT-B/32", device=device, download_root=cache_dir)
17
  except Exception as e:
18
  st.error(f"Failed to load CLIP model: {e}")
19
  st.stop()
 
20
  st.session_state.model = model
21
  st.session_state.preprocess = preprocess
22
  st.session_state.device = device
@@ -24,10 +26,12 @@ if 'model' not in st.session_state:
24
  st.session_state.demo_image_paths = []
25
  st.session_state.user_images = []
26
 
27
- # Initialize ChromaDB
28
  if 'chroma_client' not in st.session_state:
29
  try:
30
- st.session_state.chroma_client = chromadb.PersistentClient(path="./chroma_db")
 
 
31
  st.session_state.demo_collection = st.session_state.chroma_client.get_or_create_collection(
32
  name="demo_images", metadata={"hnsw:space": "cosine"}
33
  )
@@ -38,15 +42,22 @@ if 'chroma_client' not in st.session_state:
38
  st.error(f"Failed to initialize ChromaDB: {e}")
39
  st.stop()
40
 
41
- # Load demo images once
42
  if not st.session_state.get("demo_images_loaded", False):
43
  demo_folder = "demo_images"
44
  if os.path.exists(demo_folder):
45
- demo_image_paths = [os.path.join(demo_folder, f) for f in os.listdir(demo_folder) if f.lower().endswith(('.png', '.jpg', '.jpeg'))]
 
 
46
  if demo_image_paths:
47
  st.session_state.demo_image_paths = demo_image_paths
48
  st.session_state.demo_images = [Image.open(path).convert("RGB") for path in demo_image_paths]
49
- st.session_state.demo_collection.delete(ids=[str(i) for i in range(len(demo_image_paths))])
 
 
 
 
 
50
 
51
  embeddings, ids, metadatas = [], [], []
52
  for i, img in enumerate(st.session_state.demo_images):
@@ -71,27 +82,34 @@ if not st.session_state.get("demo_images_loaded", False):
71
  else:
72
  st.warning("Folder 'demo_images' does not exist.")
73
 
74
- # UI title
75
  st.title("πŸ” Image Search with CLIP")
76
-
77
- # Mode selection
78
  mode = st.radio("Select mode", ("Search in Demo Images", "Search in My Images"))
79
 
80
- # Upload user images
81
  if mode == "Search in My Images":
82
  st.subheader("Upload Your Images")
83
  uploaded_files = st.file_uploader("Choose images", type=['png', 'jpg', 'jpeg'], accept_multiple_files=True)
84
 
85
  if uploaded_files:
86
  st.session_state.user_images = []
87
- st.session_state.user_collection.delete(ids=[str(i) for i in range(st.session_state.user_collection.count())])
 
 
 
 
 
 
 
88
 
89
  for i, uploaded_file in enumerate(uploaded_files):
90
  img = Image.open(uploaded_file).convert("RGB")
91
  st.session_state.user_images.append(img)
 
92
  img_pre = st.session_state.preprocess(img).unsqueeze(0).to(st.session_state.device)
93
  with torch.no_grad():
94
  embedding = st.session_state.model.encode_image(img_pre).cpu().numpy().flatten()
 
95
  try:
96
  st.session_state.user_collection.add(
97
  embeddings=[embedding],
@@ -106,17 +124,19 @@ if mode == "Search in My Images":
106
  else:
107
  st.warning("Upload failed.")
108
 
109
- # Query image
110
  st.subheader("Upload Query Image")
111
  query_file = st.file_uploader("Choose a query image", type=['png', 'jpg', 'jpeg'])
112
 
113
  if query_file is not None:
114
  query_img = Image.open(query_file).convert("RGB")
115
  st.image(query_img, caption="Query Image", width=200)
 
116
  query_pre = st.session_state.preprocess(query_img).unsqueeze(0).to(st.session_state.device)
117
  with torch.no_grad():
118
  query_embedding = st.session_state.model.encode_image(query_pre).cpu().numpy().flatten()
119
 
 
120
  if mode == "Search in Demo Images":
121
  if st.session_state.demo_collection.count() > 0:
122
  results = st.session_state.demo_collection.query(
@@ -132,10 +152,11 @@ if query_file is not None:
132
  for i, (idx, sim) in enumerate(zip(ids, similarities)):
133
  img_idx = int(idx)
134
  with cols[i]:
135
- st.image(st.session_state.demo_images[img_idx], caption=f"Similarity: {sim:.4f}", width=150)
136
  else:
137
  st.error("No demo images available.")
138
 
 
139
  elif mode == "Search in My Images":
140
  if st.session_state.user_collection.count() > 0:
141
  results = st.session_state.user_collection.query(
@@ -151,6 +172,6 @@ if query_file is not None:
151
  for i, (idx, sim) in enumerate(zip(ids, similarities)):
152
  img_idx = int(idx)
153
  with cols[i]:
154
- st.image(st.session_state.user_images[img_idx], caption=f"Similarity: {sim:.4f}", width=150)
155
  else:
156
  st.error("Please upload some images first.")
 
6
  import numpy as np
7
  import chromadb
8
  from chromadb.utils import embedding_functions
9
+ import tempfile
10
 
11
+ # ----- Session Initialization -----
12
  if 'model' not in st.session_state:
13
  device = "cuda" if torch.cuda.is_available() else "cpu"
14
+ cache_dir = tempfile.gettempdir()
15
+
16
  try:
17
  model, preprocess = clip.load("ViT-B/32", device=device, download_root=cache_dir)
18
  except Exception as e:
19
  st.error(f"Failed to load CLIP model: {e}")
20
  st.stop()
21
+
22
  st.session_state.model = model
23
  st.session_state.preprocess = preprocess
24
  st.session_state.device = device
 
26
  st.session_state.demo_image_paths = []
27
  st.session_state.user_images = []
28
 
29
+ # ----- Initialize ChromaDB in Temp Dir -----
30
  if 'chroma_client' not in st.session_state:
31
  try:
32
+ chroma_path = os.path.join(tempfile.gettempdir(), "chroma_db")
33
+ st.session_state.chroma_client = chromadb.PersistentClient(path=chroma_path)
34
+
35
  st.session_state.demo_collection = st.session_state.chroma_client.get_or_create_collection(
36
  name="demo_images", metadata={"hnsw:space": "cosine"}
37
  )
 
42
  st.error(f"Failed to initialize ChromaDB: {e}")
43
  st.stop()
44
 
45
+ # ----- Load Demo Images -----
46
  if not st.session_state.get("demo_images_loaded", False):
47
  demo_folder = "demo_images"
48
  if os.path.exists(demo_folder):
49
+ demo_image_paths = [os.path.join(demo_folder, f) for f in os.listdir(demo_folder)
50
+ if f.lower().endswith(('.png', '.jpg', '.jpeg'))]
51
+
52
  if demo_image_paths:
53
  st.session_state.demo_image_paths = demo_image_paths
54
  st.session_state.demo_images = [Image.open(path).convert("RGB") for path in demo_image_paths]
55
+
56
+ # Clear previous collection
57
+ try:
58
+ st.session_state.demo_collection.delete(ids=[str(i) for i in range(len(demo_image_paths))])
59
+ except:
60
+ pass # Collection might be empty
61
 
62
  embeddings, ids, metadatas = [], [], []
63
  for i, img in enumerate(st.session_state.demo_images):
 
82
  else:
83
  st.warning("Folder 'demo_images' does not exist.")
84
 
85
+ # ----- UI -----
86
  st.title("πŸ” Image Search with CLIP")
 
 
87
  mode = st.radio("Select mode", ("Search in Demo Images", "Search in My Images"))
88
 
89
+ # ----- Upload My Images -----
90
  if mode == "Search in My Images":
91
  st.subheader("Upload Your Images")
92
  uploaded_files = st.file_uploader("Choose images", type=['png', 'jpg', 'jpeg'], accept_multiple_files=True)
93
 
94
  if uploaded_files:
95
  st.session_state.user_images = []
96
+
97
+ # Clear user collection
98
+ try:
99
+ st.session_state.user_collection.delete(ids=[
100
+ str(i) for i in range(st.session_state.user_collection.count())
101
+ ])
102
+ except:
103
+ pass
104
 
105
  for i, uploaded_file in enumerate(uploaded_files):
106
  img = Image.open(uploaded_file).convert("RGB")
107
  st.session_state.user_images.append(img)
108
+
109
  img_pre = st.session_state.preprocess(img).unsqueeze(0).to(st.session_state.device)
110
  with torch.no_grad():
111
  embedding = st.session_state.model.encode_image(img_pre).cpu().numpy().flatten()
112
+
113
  try:
114
  st.session_state.user_collection.add(
115
  embeddings=[embedding],
 
124
  else:
125
  st.warning("Upload failed.")
126
 
127
+ # ----- Query Image -----
128
  st.subheader("Upload Query Image")
129
  query_file = st.file_uploader("Choose a query image", type=['png', 'jpg', 'jpeg'])
130
 
131
  if query_file is not None:
132
  query_img = Image.open(query_file).convert("RGB")
133
  st.image(query_img, caption="Query Image", width=200)
134
+
135
  query_pre = st.session_state.preprocess(query_img).unsqueeze(0).to(st.session_state.device)
136
  with torch.no_grad():
137
  query_embedding = st.session_state.model.encode_image(query_pre).cpu().numpy().flatten()
138
 
139
+ # ----- Search in Demo -----
140
  if mode == "Search in Demo Images":
141
  if st.session_state.demo_collection.count() > 0:
142
  results = st.session_state.demo_collection.query(
 
152
  for i, (idx, sim) in enumerate(zip(ids, similarities)):
153
  img_idx = int(idx)
154
  with cols[i]:
155
+ st.image(st.session_state.demo_images[img_idx], caption=f"Sim: {sim:.4f}", width=150)
156
  else:
157
  st.error("No demo images available.")
158
 
159
+ # ----- Search in User Uploads -----
160
  elif mode == "Search in My Images":
161
  if st.session_state.user_collection.count() > 0:
162
  results = st.session_state.user_collection.query(
 
172
  for i, (idx, sim) in enumerate(zip(ids, similarities)):
173
  img_idx = int(idx)
174
  with cols[i]:
175
+ st.image(st.session_state.user_images[img_idx], caption=f"Sim: {sim:.4f}", width=150)
176
  else:
177
  st.error("Please upload some images first.")