NEXAS commited on
Commit
5eb07d7
·
verified ·
1 Parent(s): 67a9702

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +40 -50
src/streamlit_app.py CHANGED
@@ -10,9 +10,8 @@ from chromadb.utils import embedding_functions
10
  # Initialize session state
11
  if 'model' not in st.session_state:
12
  device = "cuda" if torch.cuda.is_available() else "cpu"
13
- # Set a custom cache directory for CLIP model weights
14
  cache_dir = "./clip_cache"
15
- os.makedirs(cache_dir, exist_ok=True) # Create cache directory if it doesn't exist
16
  try:
17
  model, preprocess = clip.load("ViT-B/32", device=device, download_root=cache_dir)
18
  except Exception as e:
@@ -25,39 +24,31 @@ if 'model' not in st.session_state:
25
  st.session_state.demo_image_paths = []
26
  st.session_state.user_images = []
27
 
28
- # Initialize ChromaDB client
29
  if 'chroma_client' not in st.session_state:
30
  try:
31
  st.session_state.chroma_client = chromadb.PersistentClient(path="./chroma_db")
32
- # Create or get collections
33
  st.session_state.demo_collection = st.session_state.chroma_client.get_or_create_collection(
34
- name="demo_images",
35
- metadata={"hnsw:space": "cosine"} # Use cosine similarity
36
  )
37
  st.session_state.user_collection = st.session_state.chroma_client.get_or_create_collection(
38
- name="user_images",
39
- metadata={"hnsw:space": "cosine"}
40
  )
41
  except Exception as e:
42
- st.error(f"Failed to initialize ChromaDB collections: {e}")
43
  st.stop()
44
 
45
- # Load demo images into ChromaDB
46
- if not st.session_state.demo_images:
47
  demo_folder = "demo_images"
48
  if os.path.exists(demo_folder):
49
- demo_image_paths = [os.path.join(demo_folder, f) for f in os.listdir(demo_folder) if f.endswith(('.png', '.jpg', '.jpeg'))]
50
- if len(demo_image_paths) > 0:
51
  st.session_state.demo_image_paths = demo_image_paths
52
- st.session_state.demo_images = [Image.open(path) for path in demo_image_paths]
53
-
54
- # Clear existing demo collection to avoid duplicates
55
  st.session_state.demo_collection.delete(ids=[str(i) for i in range(len(demo_image_paths))])
56
-
57
- # Compute and store embeddings
58
- embeddings = []
59
- ids = []
60
- metadatas = []
61
  for i, img in enumerate(st.session_state.demo_images):
62
  img_pre = st.session_state.preprocess(img).unsqueeze(0).to(st.session_state.device)
63
  with torch.no_grad():
@@ -65,37 +56,38 @@ if not st.session_state.demo_images:
65
  embeddings.append(embedding)
66
  ids.append(str(i))
67
  metadatas.append({"path": demo_image_paths[i]})
68
-
69
- # Add to ChromaDB
70
  try:
71
  st.session_state.demo_collection.add(
72
  embeddings=embeddings,
73
  ids=ids,
74
  metadatas=metadatas
75
  )
 
76
  except Exception as e:
77
  st.error(f"Failed to add demo images to ChromaDB: {e}")
78
  else:
79
- st.warning("No images found in 'demo_images' folder. Demo mode will be limited.")
 
 
80
 
81
- # Streamlit UI
82
- st.title("Image Search with CLIP")
83
 
84
  # Mode selection
85
  mode = st.radio("Select mode", ("Search in Demo Images", "Search in My Images"))
86
 
87
- # User images upload
88
  if mode == "Search in My Images":
89
  st.subheader("Upload Your Images")
90
  uploaded_files = st.file_uploader("Choose images", type=['png', 'jpg', 'jpeg'], accept_multiple_files=True)
91
-
92
  if uploaded_files:
93
- # Clear_previous user images and collection
94
  st.session_state.user_images = []
95
  st.session_state.user_collection.delete(ids=[str(i) for i in range(st.session_state.user_collection.count())])
96
-
97
  for i, uploaded_file in enumerate(uploaded_files):
98
- img = Image.open(uploaded_file)
99
  st.session_state.user_images.append(img)
100
  img_pre = st.session_state.preprocess(img).unsqueeze(0).to(st.session_state.device)
101
  with torch.no_grad():
@@ -107,60 +99,58 @@ if mode == "Search in My Images":
107
  metadatas=[{"index": i}]
108
  )
109
  except Exception as e:
110
- st.error(f"Failed to add user image {i} to ChromaDB: {e}")
111
-
112
  if st.session_state.user_collection.count() > 0:
113
- st.success(f"Uploaded {len(st.session_state.user_images)} images successfully.")
114
  else:
115
- st.warning("No images uploaded yet.")
116
 
117
- # Query image upload
118
- st.subheader Snip: st.subheader("Upload Query Image")
119
  query_file = st.file_uploader("Choose a query image", type=['png', 'jpg', 'jpeg'])
120
 
121
  if query_file is not None:
122
- query_img = Image.open(query_file)
123
  st.image(query_img, caption="Query Image", width=200)
124
  query_pre = st.session_state.preprocess(query_img).unsqueeze(0).to(st.session_state.device)
125
  with torch.no_grad():
126
  query_embedding = st.session_state.model.encode_image(query_pre).cpu().numpy().flatten()
127
-
128
  if mode == "Search in Demo Images":
129
  if st.session_state.demo_collection.count() > 0:
130
- # Query ChromaDB
131
  results = st.session_state.demo_collection.query(
132
  query_embeddings=[query_embedding],
133
  n_results=min(5, st.session_state.demo_collection.count())
134
  )
135
  distances = results['distances'][0]
136
  ids = results['ids'][0]
137
- similarities = [1 - dist for dist in distances] # Convert distance to similarity
138
-
139
- st.subheader("Top 5 Similar Images")
140
  cols = st.columns(5)
141
  for i, (idx, sim) in enumerate(zip(ids, similarities)):
142
  img_idx = int(idx)
143
  with cols[i]:
144
  st.image(st.session_state.demo_images[img_idx], caption=f"Similarity: {sim:.4f}", width=150)
145
  else:
146
- st.error("No demo images available. Please check the 'demo_images' folder.")
147
-
148
  elif mode == "Search in My Images":
149
  if st.session_state.user_collection.count() > 0:
150
- # Query ChromaDB
151
  results = st.session_state.user_collection.query(
152
  query_embeddings=[query_embedding],
153
  n_results=min(5, st.session_state.user_collection.count())
154
  )
155
  distances = results['distances'][0]
156
  ids = results['ids'][0]
157
- similarities = [1 - dist for dist in distances] # Convert distance to similarity
158
-
159
- st.subheader("Top 5 Similar Images")
160
  cols = st.columns(5)
161
  for i, (idx, sim) in enumerate(zip(ids, similarities)):
162
  img_idx = int(idx)
163
  with cols[i]:
164
  st.image(st.session_state.user_images[img_idx], caption=f"Similarity: {sim:.4f}", width=150)
165
  else:
166
- st.error("No user images uploaded yet. Please upload images first.")
 
10
  # Initialize session state
11
  if 'model' not in st.session_state:
12
  device = "cuda" if torch.cuda.is_available() else "cpu"
 
13
  cache_dir = "./clip_cache"
14
+ os.makedirs(cache_dir, exist_ok=True)
15
  try:
16
  model, preprocess = clip.load("ViT-B/32", device=device, download_root=cache_dir)
17
  except Exception as e:
 
24
  st.session_state.demo_image_paths = []
25
  st.session_state.user_images = []
26
 
27
+ # Initialize ChromaDB
28
  if 'chroma_client' not in st.session_state:
29
  try:
30
  st.session_state.chroma_client = chromadb.PersistentClient(path="./chroma_db")
 
31
  st.session_state.demo_collection = st.session_state.chroma_client.get_or_create_collection(
32
+ name="demo_images", metadata={"hnsw:space": "cosine"}
 
33
  )
34
  st.session_state.user_collection = st.session_state.chroma_client.get_or_create_collection(
35
+ name="user_images", metadata={"hnsw:space": "cosine"}
 
36
  )
37
  except Exception as e:
38
+ st.error(f"Failed to initialize ChromaDB: {e}")
39
  st.stop()
40
 
41
+ # Load demo images only once
42
+ if not st.session_state.get("demo_images_loaded", False):
43
  demo_folder = "demo_images"
44
  if os.path.exists(demo_folder):
45
+ demo_image_paths = [os.path.join(demo_folder, f) for f in os.listdir(demo_folder) if f.lower().endswith(('.png', '.jpg', '.jpeg'))]
46
+ if demo_image_paths:
47
  st.session_state.demo_image_paths = demo_image_paths
48
+ st.session_state.demo_images = [Image.open(path).convert("RGB") for path in demo_image_paths]
 
 
49
  st.session_state.demo_collection.delete(ids=[str(i) for i in range(len(demo_image_paths))])
50
+
51
+ embeddings, ids, metadatas = [], [], []
 
 
 
52
  for i, img in enumerate(st.session_state.demo_images):
53
  img_pre = st.session_state.preprocess(img).unsqueeze(0).to(st.session_state.device)
54
  with torch.no_grad():
 
56
  embeddings.append(embedding)
57
  ids.append(str(i))
58
  metadatas.append({"path": demo_image_paths[i]})
59
+
 
60
  try:
61
  st.session_state.demo_collection.add(
62
  embeddings=embeddings,
63
  ids=ids,
64
  metadatas=metadatas
65
  )
66
+ st.session_state.demo_images_loaded = True
67
  except Exception as e:
68
  st.error(f"Failed to add demo images to ChromaDB: {e}")
69
  else:
70
+ st.warning("No images found in 'demo_images' folder.")
71
+ else:
72
+ st.warning("Folder 'demo_images' does not exist.")
73
 
74
+ # UI title
75
+ st.title("🔍 Image Search with CLIP")
76
 
77
  # Mode selection
78
  mode = st.radio("Select mode", ("Search in Demo Images", "Search in My Images"))
79
 
80
+ # Upload user images
81
  if mode == "Search in My Images":
82
  st.subheader("Upload Your Images")
83
  uploaded_files = st.file_uploader("Choose images", type=['png', 'jpg', 'jpeg'], accept_multiple_files=True)
84
+
85
  if uploaded_files:
 
86
  st.session_state.user_images = []
87
  st.session_state.user_collection.delete(ids=[str(i) for i in range(st.session_state.user_collection.count())])
88
+
89
  for i, uploaded_file in enumerate(uploaded_files):
90
+ img = Image.open(uploaded_file).convert("RGB")
91
  st.session_state.user_images.append(img)
92
  img_pre = st.session_state.preprocess(img).unsqueeze(0).to(st.session_state.device)
93
  with torch.no_grad():
 
99
  metadatas=[{"index": i}]
100
  )
101
  except Exception as e:
102
+ st.error(f"Failed to add image {i}: {e}")
103
+
104
  if st.session_state.user_collection.count() > 0:
105
+ st.success(f"Uploaded {len(st.session_state.user_images)} images.")
106
  else:
107
+ st.warning("Upload failed.")
108
 
109
+ # Query image
110
+ st.subheader("Upload Query Image")
111
  query_file = st.file_uploader("Choose a query image", type=['png', 'jpg', 'jpeg'])
112
 
113
  if query_file is not None:
114
+ query_img = Image.open(query_file).convert("RGB")
115
  st.image(query_img, caption="Query Image", width=200)
116
  query_pre = st.session_state.preprocess(query_img).unsqueeze(0).to(st.session_state.device)
117
  with torch.no_grad():
118
  query_embedding = st.session_state.model.encode_image(query_pre).cpu().numpy().flatten()
119
+
120
  if mode == "Search in Demo Images":
121
  if st.session_state.demo_collection.count() > 0:
 
122
  results = st.session_state.demo_collection.query(
123
  query_embeddings=[query_embedding],
124
  n_results=min(5, st.session_state.demo_collection.count())
125
  )
126
  distances = results['distances'][0]
127
  ids = results['ids'][0]
128
+ similarities = [1 - dist for dist in distances]
129
+
130
+ st.subheader("Top 5 Similar Demo Images")
131
  cols = st.columns(5)
132
  for i, (idx, sim) in enumerate(zip(ids, similarities)):
133
  img_idx = int(idx)
134
  with cols[i]:
135
  st.image(st.session_state.demo_images[img_idx], caption=f"Similarity: {sim:.4f}", width=150)
136
  else:
137
+ st.error("No demo images available.")
138
+
139
  elif mode == "Search in My Images":
140
  if st.session_state.user_collection.count() > 0:
 
141
  results = st.session_state.user_collection.query(
142
  query_embeddings=[query_embedding],
143
  n_results=min(5, st.session_state.user_collection.count())
144
  )
145
  distances = results['distances'][0]
146
  ids = results['ids'][0]
147
+ similarities = [1 - dist for dist in distances]
148
+
149
+ st.subheader("Top 5 Similar Uploaded Images")
150
  cols = st.columns(5)
151
  for i, (idx, sim) in enumerate(zip(ids, similarities)):
152
  img_idx = int(idx)
153
  with cols[i]:
154
  st.image(st.session_state.user_images[img_idx], caption=f"Similarity: {sim:.4f}", width=150)
155
  else:
156
+ st.error("Please upload some images first.")