NEXAS committed on
Commit
67a9702
·
verified ·
1 Parent(s): 4373071

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +93 -28
src/streamlit_app.py CHANGED
@@ -4,21 +4,45 @@ import clip
4
  from PIL import Image
5
  import os
6
  import numpy as np
 
 
7
 
8
  # Initialize session state
9
  if 'model' not in st.session_state:
10
  device = "cuda" if torch.cuda.is_available() else "cpu"
11
- model, preprocess = clip.load("ViT-B/32", device=device)
 
 
 
 
 
 
 
12
  st.session_state.model = model
13
  st.session_state.preprocess = preprocess
14
  st.session_state.device = device
15
  st.session_state.demo_images = []
16
- st.session_state.demo_encodings = []
17
  st.session_state.demo_image_paths = []
18
  st.session_state.user_images = []
19
- st.session_state.user_encodings = []
20
 
21
- # Load demo images
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
  if not st.session_state.demo_images:
23
  demo_folder = "demo_images"
24
  if os.path.exists(demo_folder):
@@ -26,10 +50,31 @@ if not st.session_state.demo_images:
26
  if len(demo_image_paths) > 0:
27
  st.session_state.demo_image_paths = demo_image_paths
28
  st.session_state.demo_images = [Image.open(path) for path in demo_image_paths]
29
- demo_preprocessed = [st.session_state.preprocess(img).unsqueeze(0).to(st.session_state.device) for img in st.session_state.demo_images]
30
- with torch.no_grad():
31
- demo_encodings = [st.session_state.model.encode_image(img) for img in demo_preprocessed]
32
- st.session_state.demo_encodings = torch.cat(demo_encodings, dim=0)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
  else:
34
  st.warning("No images found in 'demo_images' folder. Demo mode will be limited.")
35
 
@@ -45,26 +90,32 @@ if mode == "Search in My Images":
45
  uploaded_files = st.file_uploader("Choose images", type=['png', 'jpg', 'jpeg'], accept_multiple_files=True)
46
 
47
  if uploaded_files:
48
- # Clear previous user images to avoid duplicates
49
  st.session_state.user_images = []
50
- st.session_state.user_encodings = []
51
 
52
- for uploaded_file in uploaded_files:
53
  img = Image.open(uploaded_file)
54
  st.session_state.user_images.append(img)
55
  img_pre = st.session_state.preprocess(img).unsqueeze(0).to(st.session_state.device)
56
  with torch.no_grad():
57
- encoding = st.session_state.model.encode_image(img_pre)
58
- st.session_state.user_encodings.append(encoding)
 
 
 
 
 
 
 
59
 
60
- if st.session_state.user_encodings:
61
- st.session_state.user_encodings = torch.cat(st.session_state.user_encodings, dim=0)
62
  st.success(f"Uploaded {len(st.session_state.user_images)} images successfully.")
63
  else:
64
  st.warning("No images uploaded yet.")
65
 
66
  # Query image upload
67
- st.subheader("Upload Query Image")
68
  query_file = st.file_uploader("Choose a query image", type=['png', 'jpg', 'jpeg'])
69
 
70
  if query_file is not None:
@@ -72,30 +123,44 @@ if query_file is not None:
72
  st.image(query_img, caption="Query Image", width=200)
73
  query_pre = st.session_state.preprocess(query_img).unsqueeze(0).to(st.session_state.device)
74
  with torch.no_grad():
75
- query_encoding = st.session_state.model.encode_image(query_pre)
76
 
77
  if mode == "Search in Demo Images":
78
- if st.session_state.demo_encodings is not None and len(st.session_state.demo_encodings) > 0:
79
- similarities = (st.session_state.demo_encodings @ query_encoding.T).squeeze()
80
- top_indices = torch.topk(similarities, min(5, len(similarities))).indices.cpu().numpy()
 
 
 
 
 
 
81
 
82
  st.subheader("Top 5 Similar Images")
83
  cols = st.columns(5)
84
- for i, idx in enumerate(top_indices):
 
85
  with cols[i]:
86
- st.image(st.session_state.demo_images[idx], caption=f"Similarity: {similarities[idx]:.4f}", width=150)
87
  else:
88
  st.error("No demo images available. Please check the 'demo_images' folder.")
89
 
90
  elif mode == "Search in My Images":
91
- if st.session_state.user_encodings is not None and len(st.session_state.user_encodings) > 0:
92
- similarities = (st.session_state.user_encodings @ query_encoding.T).squeeze()
93
- top_indices = torch.topk(similarities, min(5, len(similarities))).indices.cpu().numpy()
 
 
 
 
 
 
94
 
95
  st.subheader("Top 5 Similar Images")
96
  cols = st.columns(5)
97
- for i, idx in enumerate(top_indices):
 
98
  with cols[i]:
99
- st.image(st.session_state.user_images[idx], caption=f"Similarity: {similarities[idx]:.4f}", width=150)
100
  else:
101
- st.error("No user images uploaded yet. Please upload images first.")
 
4
  from PIL import Image
5
  import os
6
  import numpy as np
7
+ import chromadb
8
+ from chromadb.utils import embedding_functions
9
 
10
  # Initialize session state
11
  if 'model' not in st.session_state:
12
  device = "cuda" if torch.cuda.is_available() else "cpu"
13
+ # Set a custom cache directory for CLIP model weights
14
+ cache_dir = "./clip_cache"
15
+ os.makedirs(cache_dir, exist_ok=True) # Create cache directory if it doesn't exist
16
+ try:
17
+ model, preprocess = clip.load("ViT-B/32", device=device, download_root=cache_dir)
18
+ except Exception as e:
19
+ st.error(f"Failed to load CLIP model: {e}")
20
+ st.stop()
21
  st.session_state.model = model
22
  st.session_state.preprocess = preprocess
23
  st.session_state.device = device
24
  st.session_state.demo_images = []
 
25
  st.session_state.demo_image_paths = []
26
  st.session_state.user_images = []
 
27
 
28
+ # Initialize ChromaDB client
29
+ if 'chroma_client' not in st.session_state:
30
+ try:
31
+ st.session_state.chroma_client = chromadb.PersistentClient(path="./chroma_db")
32
+ # Create or get collections
33
+ st.session_state.demo_collection = st.session_state.chroma_client.get_or_create_collection(
34
+ name="demo_images",
35
+ metadata={"hnsw:space": "cosine"} # Use cosine similarity
36
+ )
37
+ st.session_state.user_collection = st.session_state.chroma_client.get_or_create_collection(
38
+ name="user_images",
39
+ metadata={"hnsw:space": "cosine"}
40
+ )
41
+ except Exception as e:
42
+ st.error(f"Failed to initialize ChromaDB collections: {e}")
43
+ st.stop()
44
+
45
+ # Load demo images into ChromaDB
46
  if not st.session_state.demo_images:
47
  demo_folder = "demo_images"
48
  if os.path.exists(demo_folder):
 
50
  if len(demo_image_paths) > 0:
51
  st.session_state.demo_image_paths = demo_image_paths
52
  st.session_state.demo_images = [Image.open(path) for path in demo_image_paths]
53
+
54
+ # Clear existing demo collection to avoid duplicates
55
+ st.session_state.demo_collection.delete(ids=[str(i) for i in range(len(demo_image_paths))])
56
+
57
+ # Compute and store embeddings
58
+ embeddings = []
59
+ ids = []
60
+ metadatas = []
61
+ for i, img in enumerate(st.session_state.demo_images):
62
+ img_pre = st.session_state.preprocess(img).unsqueeze(0).to(st.session_state.device)
63
+ with torch.no_grad():
64
+ embedding = st.session_state.model.encode_image(img_pre).cpu().numpy().flatten()
65
+ embeddings.append(embedding)
66
+ ids.append(str(i))
67
+ metadatas.append({"path": demo_image_paths[i]})
68
+
69
+ # Add to ChromaDB
70
+ try:
71
+ st.session_state.demo_collection.add(
72
+ embeddings=embeddings,
73
+ ids=ids,
74
+ metadatas=metadatas
75
+ )
76
+ except Exception as e:
77
+ st.error(f"Failed to add demo images to ChromaDB: {e}")
78
  else:
79
  st.warning("No images found in 'demo_images' folder. Demo mode will be limited.")
80
 
 
90
  uploaded_files = st.file_uploader("Choose images", type=['png', 'jpg', 'jpeg'], accept_multiple_files=True)
91
 
92
  if uploaded_files:
93
+ # Clear previous user images and collection
94
  st.session_state.user_images = []
95
+ st.session_state.user_collection.delete(ids=[str(i) for i in range(st.session_state.user_collection.count())])
96
 
97
+ for i, uploaded_file in enumerate(uploaded_files):
98
  img = Image.open(uploaded_file)
99
  st.session_state.user_images.append(img)
100
  img_pre = st.session_state.preprocess(img).unsqueeze(0).to(st.session_state.device)
101
  with torch.no_grad():
102
+ embedding = st.session_state.model.encode_image(img_pre).cpu().numpy().flatten()
103
+ try:
104
+ st.session_state.user_collection.add(
105
+ embeddings=[embedding],
106
+ ids=[str(i)],
107
+ metadatas=[{"index": i}]
108
+ )
109
+ except Exception as e:
110
+ st.error(f"Failed to add user image {i} to ChromaDB: {e}")
111
 
112
+ if st.session_state.user_collection.count() > 0:
 
113
  st.success(f"Uploaded {len(st.session_state.user_images)} images successfully.")
114
  else:
115
  st.warning("No images uploaded yet.")
116
 
117
  # Query image upload
118
+ st.subheader("Upload Query Image")
119
  query_file = st.file_uploader("Choose a query image", type=['png', 'jpg', 'jpeg'])
120
 
121
  if query_file is not None:
 
123
  st.image(query_img, caption="Query Image", width=200)
124
  query_pre = st.session_state.preprocess(query_img).unsqueeze(0).to(st.session_state.device)
125
  with torch.no_grad():
126
+ query_embedding = st.session_state.model.encode_image(query_pre).cpu().numpy().flatten()
127
 
128
  if mode == "Search in Demo Images":
129
+ if st.session_state.demo_collection.count() > 0:
130
+ # Query ChromaDB
131
+ results = st.session_state.demo_collection.query(
132
+ query_embeddings=[query_embedding],
133
+ n_results=min(5, st.session_state.demo_collection.count())
134
+ )
135
+ distances = results['distances'][0]
136
+ ids = results['ids'][0]
137
+ similarities = [1 - dist for dist in distances] # Convert distance to similarity
138
 
139
  st.subheader("Top 5 Similar Images")
140
  cols = st.columns(5)
141
+ for i, (idx, sim) in enumerate(zip(ids, similarities)):
142
+ img_idx = int(idx)
143
  with cols[i]:
144
+ st.image(st.session_state.demo_images[img_idx], caption=f"Similarity: {sim:.4f}", width=150)
145
  else:
146
  st.error("No demo images available. Please check the 'demo_images' folder.")
147
 
148
  elif mode == "Search in My Images":
149
+ if st.session_state.user_collection.count() > 0:
150
+ # Query ChromaDB
151
+ results = st.session_state.user_collection.query(
152
+ query_embeddings=[query_embedding],
153
+ n_results=min(5, st.session_state.user_collection.count())
154
+ )
155
+ distances = results['distances'][0]
156
+ ids = results['ids'][0]
157
+ similarities = [1 - dist for dist in distances] # Convert distance to similarity
158
 
159
  st.subheader("Top 5 Similar Images")
160
  cols = st.columns(5)
161
+ for i, (idx, sim) in enumerate(zip(ids, similarities)):
162
+ img_idx = int(idx)
163
  with cols[i]:
164
+ st.image(st.session_state.user_images[img_idx], caption=f"Similarity: {sim:.4f}", width=150)
165
  else:
166
+ st.error("No user images uploaded yet. Please upload images first.")