NEXAS committed on
Commit
b8fa391
Β·
verified Β·
1 Parent(s): d7c7b18

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +66 -41
src/streamlit_app.py CHANGED
@@ -7,6 +7,7 @@ import numpy as np
7
  import chromadb
8
  import requests
9
  import tempfile
 
10
 
11
  # ----- Setup -----
12
  CACHE_DIR = tempfile.gettempdir()
@@ -14,6 +15,16 @@ CHROMA_PATH = os.path.join(CACHE_DIR, "chroma_db")
14
  DEMO_DIR = os.path.join(CACHE_DIR, "demo_images")
15
  os.makedirs(DEMO_DIR, exist_ok=True)
16
 
 
 
 
 
 
 
 
 
 
 
17
  # ----- Load CLIP Model -----
18
  if 'model' not in st.session_state:
19
  device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -32,33 +43,46 @@ if 'chroma_client' not in st.session_state:
32
  name="user_images", metadata={"hnsw:space": "cosine"}
33
  )
34
 
 
35
  st.title("πŸ” CLIP-Based Image Search")
36
 
37
- # Dataset selection
38
  col1, col2 = st.columns(2)
39
- use_demo = col1.button("πŸ“¦ Use Demo Images")
40
- upload_own = col2.button("πŸ“€ Upload Your Images")
41
-
42
- dataset_loaded = False
43
- dataset_name = None
44
-
45
- # ----- Handle Demo Images -----
46
- if use_demo:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
47
  with st.spinner("Downloading and indexing demo images..."):
48
  st.session_state.demo_collection.delete(ids=[str(i) for i in range(50)])
49
- demo_image_paths = []
50
- demo_images = []
51
-
52
  for i in range(50):
53
  path = os.path.join(DEMO_DIR, f"img_{i+1:02}.jpg")
54
  if not os.path.exists(path):
55
  url = f"https://picsum.photos/seed/{i}/1024/768"
56
- response = requests.get(url)
57
- if response.status_code == 200:
58
- with open(path, "wb") as f:
59
- f.write(response.content)
60
- demo_image_paths.append(path)
61
- demo_images.append(Image.open(path).convert("RGB"))
62
 
63
  embeddings, ids, metadatas = [], [], []
64
  for i, img in enumerate(demo_images):
@@ -71,13 +95,12 @@ if use_demo:
71
 
72
  st.session_state.demo_collection.add(embeddings=embeddings, ids=ids, metadatas=metadatas)
73
  st.session_state.demo_images = demo_images
74
- dataset_loaded = True
75
- dataset_name = "demo"
76
 
77
- st.success("Demo images loaded!")
78
 
79
- # ----- Handle User Uploads -----
80
- if upload_own:
81
  uploaded = st.file_uploader("Upload your images", type=["jpg", "jpeg", "png"], accept_multiple_files=True)
82
  if uploaded:
83
  st.session_state.user_collection.delete(ids=[
@@ -85,9 +108,11 @@ if upload_own:
85
  ])
86
  user_images = []
87
  for i, file in enumerate(uploaded):
88
- img = Image.open(file).convert("RGB")
 
 
 
89
  user_images.append(img)
90
-
91
  img_tensor = st.session_state.preprocess(img).unsqueeze(0).to(st.session_state.device)
92
  with torch.no_grad():
93
  embedding = st.session_state.model.encode_image(img_tensor).cpu().numpy().flatten()
@@ -96,13 +121,12 @@ if upload_own:
96
  )
97
 
98
  st.session_state.user_images = user_images
99
- st.success(f"{len(user_images)} images uploaded.")
100
- dataset_loaded = True
101
- dataset_name = "user"
102
 
103
- # ----- Search UI -----
104
- if dataset_loaded:
105
- st.subheader("Search Section")
106
  query_type = st.radio("Search by:", ("Text", "Image"))
107
 
108
  query_embedding = None
@@ -112,18 +136,19 @@ if dataset_loaded:
112
  tokens = clip.tokenize([text_query]).to(st.session_state.device)
113
  with torch.no_grad():
114
  query_embedding = st.session_state.model.encode_text(tokens).cpu().numpy().flatten()
115
- else:
116
- img_file = st.file_uploader("Upload query image", type=["jpg", "jpeg", "png"])
117
- if img_file:
118
- query_img = Image.open(img_file).convert("RGB")
 
119
  st.image(query_img, caption="Query Image", width=200)
120
- img_tensor = st.session_state.preprocess(query_img).unsqueeze(0).to(st.session_state.device)
121
  with torch.no_grad():
122
- query_embedding = st.session_state.model.encode_image(img_tensor).cpu().numpy().flatten()
123
 
124
  # ----- Perform Search -----
125
  if query_embedding is not None:
126
- if dataset_name == "demo":
127
  collection = st.session_state.demo_collection
128
  images = st.session_state.demo_images
129
  else:
@@ -139,12 +164,12 @@ if dataset_loaded:
139
  distances = results["distances"][0]
140
  similarities = [1 - d for d in distances]
141
 
142
- st.subheader("Top Matches")
143
  cols = st.columns(len(ids))
144
  for i, (img_id, sim) in enumerate(zip(ids, similarities)):
145
  with cols[i]:
146
  st.image(images[int(img_id)], caption=f"Sim: {sim:.3f}", width=150)
147
  else:
148
- st.warning("No images in the collection.")
149
  else:
150
- st.info("Please click on one of the options above to load a dataset.")
 
7
  import chromadb
8
  import requests
9
  import tempfile
10
+ import time
11
 
12
  # ----- Setup -----
13
  CACHE_DIR = tempfile.gettempdir()
 
15
  DEMO_DIR = os.path.join(CACHE_DIR, "demo_images")
16
  os.makedirs(DEMO_DIR, exist_ok=True)
17
 
18
+ # ----- Initialize Session State -----
19
+ if 'dataset_loaded' not in st.session_state:
20
+ st.session_state.dataset_loaded = False
21
+ if 'dataset_name' not in st.session_state:
22
+ st.session_state.dataset_name = None
23
+ if 'demo_images' not in st.session_state:
24
+ st.session_state.demo_images = []
25
+ if 'user_images' not in st.session_state:
26
+ st.session_state.user_images = []
27
+
28
  # ----- Load CLIP Model -----
29
  if 'model' not in st.session_state:
30
  device = "cuda" if torch.cuda.is_available() else "cpu"
 
43
  name="user_images", metadata={"hnsw:space": "cosine"}
44
  )
45
 
46
+ # ----- Title -----
47
  st.title("πŸ” CLIP-Based Image Search")
48
 
49
+ # ----- Dataset Buttons -----
50
  col1, col2 = st.columns(2)
51
+ if col1.button("πŸ“¦ Use Demo Images"):
52
+ st.session_state.dataset_name = "demo"
53
+ st.session_state.dataset_loaded = False
54
+
55
+ if col2.button("πŸ“€ Upload Your Images"):
56
+ st.session_state.dataset_name = "user"
57
+ st.session_state.dataset_loaded = False
58
+
59
+ # ----- Download + Embed Demo Images -----
60
def download_image_with_retry(url, path, retries=3, delay=1.0):
    """Download *url* to *path*, retrying transient failures.

    Args:
        url: HTTP(S) URL to fetch.
        path: Filesystem destination for the response body.
        retries: Maximum number of attempts (default 3).
        delay: Seconds to wait between attempts (default 1.0).

    Returns:
        True if the file was written successfully, False after all
        attempts failed (including non-200 responses).
    """
    for attempt in range(retries):
        try:
            r = requests.get(url, timeout=10)
            if r.status_code == 200:
                with open(path, "wb") as f:
                    f.write(r.content)
                return True
            # Non-200: fall through to the retry delay below (the
            # original retried these immediately with no back-off).
        except (requests.RequestException, OSError):
            # Network error or failed file write; retry after a delay.
            pass
        # Only sleep when another attempt remains — the original also
        # slept after the final failure, wasting `delay` seconds.
        if attempt < retries - 1:
            time.sleep(delay)
    return False
71
+
72
+ if st.session_state.dataset_name == "demo" and not st.session_state.dataset_loaded:
73
  with st.spinner("Downloading and indexing demo images..."):
74
  st.session_state.demo_collection.delete(ids=[str(i) for i in range(50)])
75
+ demo_image_paths, demo_images = [], []
 
 
76
  for i in range(50):
77
  path = os.path.join(DEMO_DIR, f"img_{i+1:02}.jpg")
78
  if not os.path.exists(path):
79
  url = f"https://picsum.photos/seed/{i}/1024/768"
80
+ download_image_with_retry(url, path)
81
+ try:
82
+ demo_images.append(Image.open(path).convert("RGB"))
83
+ demo_image_paths.append(path)
84
+ except:
85
+ continue # skip corrupted
86
 
87
  embeddings, ids, metadatas = [], [], []
88
  for i, img in enumerate(demo_images):
 
95
 
96
  st.session_state.demo_collection.add(embeddings=embeddings, ids=ids, metadatas=metadatas)
97
  st.session_state.demo_images = demo_images
98
+ st.session_state.dataset_loaded = True
 
99
 
100
+ st.success("βœ… Demo images loaded!")
101
 
102
+ # ----- Upload User Images -----
103
+ if st.session_state.dataset_name == "user" and not st.session_state.dataset_loaded:
104
  uploaded = st.file_uploader("Upload your images", type=["jpg", "jpeg", "png"], accept_multiple_files=True)
105
  if uploaded:
106
  st.session_state.user_collection.delete(ids=[
 
108
  ])
109
  user_images = []
110
  for i, file in enumerate(uploaded):
111
+ try:
112
+ img = Image.open(file).convert("RGB")
113
+ except:
114
+ continue
115
  user_images.append(img)
 
116
  img_tensor = st.session_state.preprocess(img).unsqueeze(0).to(st.session_state.device)
117
  with torch.no_grad():
118
  embedding = st.session_state.model.encode_image(img_tensor).cpu().numpy().flatten()
 
121
  )
122
 
123
  st.session_state.user_images = user_images
124
+ st.session_state.dataset_loaded = True
125
+ st.success(f"βœ… Uploaded {len(user_images)} images.")
 
126
 
127
+ # ----- Search Section -----
128
+ if st.session_state.dataset_loaded:
129
+ st.subheader("πŸ”Ž Search Section")
130
  query_type = st.radio("Search by:", ("Text", "Image"))
131
 
132
  query_embedding = None
 
136
  tokens = clip.tokenize([text_query]).to(st.session_state.device)
137
  with torch.no_grad():
138
  query_embedding = st.session_state.model.encode_text(tokens).cpu().numpy().flatten()
139
+
140
+ elif query_type == "Image":
141
+ query_file = st.file_uploader("Upload query image", type=["jpg", "jpeg", "png"], key="query_image")
142
+ if query_file:
143
+ query_img = Image.open(query_file).convert("RGB")
144
  st.image(query_img, caption="Query Image", width=200)
145
+ query_tensor = st.session_state.preprocess(query_img).unsqueeze(0).to(st.session_state.device)
146
  with torch.no_grad():
147
+ query_embedding = st.session_state.model.encode_image(query_tensor).cpu().numpy().flatten()
148
 
149
  # ----- Perform Search -----
150
  if query_embedding is not None:
151
+ if st.session_state.dataset_name == "demo":
152
  collection = st.session_state.demo_collection
153
  images = st.session_state.demo_images
154
  else:
 
164
  distances = results["distances"][0]
165
  similarities = [1 - d for d in distances]
166
 
167
+ st.subheader("πŸ”— Top Matches")
168
  cols = st.columns(len(ids))
169
  for i, (img_id, sim) in enumerate(zip(ids, similarities)):
170
  with cols[i]:
171
  st.image(images[int(img_id)], caption=f"Sim: {sim:.3f}", width=150)
172
  else:
173
+ st.warning("No indexed images to search.")
174
  else:
175
+ st.info("πŸ‘† Please select a dataset (Demo or Upload Images) to begin.")