NEXAS committed
Commit d7c7b18 · verified · 1 Parent(s): 76c450e

Update src/streamlit_app.py

Files changed (1)
src/streamlit_app.py +92 -93
src/streamlit_app.py CHANGED
@@ -7,30 +7,12 @@ import numpy as np
 import chromadb
 import requests
 import tempfile
-from tqdm import tqdm
-
-# Get a temporary directory (automatically cleaned up after runtime ends)
-temp_dir = tempfile.gettempdir()
-demo_dir = os.path.join(temp_dir, "demo_images")
-os.makedirs(demo_dir, exist_ok=True)
-
-print(f"Saving images to: {demo_dir}")
-
-# Download 50 high-resolution images (1024x768)
-for i in tqdm(range(50), desc="Downloading images"):
-    url = f"https://picsum.photos/seed/{i}/1024/768"
-    response = requests.get(url)
-    if response.status_code == 200:
-        with open(os.path.join(demo_dir, f"img_{i+1:02}.jpg"), "wb") as f:
-            f.write(response.content)
-    else:
-        print(f"Failed to download image {i+1}")
-
-
 
 # ----- Setup -----
 CACHE_DIR = tempfile.gettempdir()
 CHROMA_PATH = os.path.join(CACHE_DIR, "chroma_db")
+DEMO_DIR = os.path.join(CACHE_DIR, "demo_images")
+os.makedirs(DEMO_DIR, exist_ok=True)
 
 # ----- Load CLIP Model -----
 if 'model' not in st.session_state:
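Note: the module-level tqdm download loop removed here reappears further down inside the "Use Demo Images" handler, now guarded by os.path.exists so files cached by an earlier run are reused. For reference, a minimal standalone sketch of that cached-download pattern (the helper name `fetch_demo_images` is hypothetical; the URL scheme and file naming follow the diff):

```python
import os
import tempfile
import requests

def fetch_demo_images(n=50):
    """Hypothetical helper mirroring the diff's cached picsum.photos download."""
    demo_dir = os.path.join(tempfile.gettempdir(), "demo_images")
    os.makedirs(demo_dir, exist_ok=True)
    paths = []
    for i in range(n):
        path = os.path.join(demo_dir, f"img_{i+1:02}.jpg")
        if not os.path.exists(path):  # reuse files cached by an earlier run
            response = requests.get(f"https://picsum.photos/seed/{i}/1024/768")
            if response.status_code == 200:
                with open(path, "wb") as f:
                    f.write(response.content)
        paths.append(path)
    return paths
```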
@@ -39,9 +21,6 @@ if 'model' not in st.session_state:
     st.session_state.model = model
     st.session_state.preprocess = preprocess
     st.session_state.device = device
-    st.session_state.demo_images = []
-    st.session_state.demo_image_paths = []
-    st.session_state.user_images = []
 
 # ----- Initialize ChromaDB -----
 if 'chroma_client' not in st.session_state:
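The context lines above reference `model`, `preprocess`, and `device`, but the load call itself sits outside the hunk. With the openai/clip package this file imports (`clip.tokenize` is used below), the block presumably looks like the following sketch; the checkpoint name "ViT-B/32" is an assumption:

```python
import torch
import clip

# Assumed shape of the load call above this hunk; "ViT-B/32" is a guess.
device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/32", device=device)
```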
@@ -53,19 +32,36 @@
         name="user_images", metadata={"hnsw:space": "cosine"}
     )
 
-# ----- Load Demo Images -----
-if not st.session_state.get("demo_images_loaded", False):
-    demo_folder = "demo_images"
-    if os.path.exists(demo_folder):
-        demo_image_paths = [os.path.join(demo_folder, f) for f in os.listdir(demo_folder)
-                            if f.lower().endswith(('.png', '.jpg', '.jpeg'))]
-        st.session_state.demo_images = [Image.open(p).convert("RGB") for p in demo_image_paths]
-        st.session_state.demo_image_paths = demo_image_paths
-
-        st.session_state.demo_collection.delete(ids=[str(i) for i in range(len(demo_image_paths))])
+st.title("🔍 CLIP-Based Image Search")
+
+# Dataset selection
+col1, col2 = st.columns(2)
+use_demo = col1.button("📦 Use Demo Images")
+upload_own = col2.button("📀 Upload Your Images")
+
+dataset_loaded = False
+dataset_name = None
+
+# ----- Handle Demo Images -----
+if use_demo:
+    with st.spinner("Downloading and indexing demo images..."):
+        st.session_state.demo_collection.delete(ids=[str(i) for i in range(50)])
+        demo_image_paths = []
+        demo_images = []
+
+        for i in range(50):
+            path = os.path.join(DEMO_DIR, f"img_{i+1:02}.jpg")
+            if not os.path.exists(path):
+                url = f"https://picsum.photos/seed/{i}/1024/768"
+                response = requests.get(url)
+                if response.status_code == 200:
+                    with open(path, "wb") as f:
+                        f.write(response.content)
+            demo_image_paths.append(path)
+            demo_images.append(Image.open(path).convert("RGB"))
 
         embeddings, ids, metadatas = [], [], []
-        for i, img in enumerate(st.session_state.demo_images):
+        for i, img in enumerate(demo_images):
             img_tensor = st.session_state.preprocess(img).unsqueeze(0).to(st.session_state.device)
             with torch.no_grad():
                 embedding = st.session_state.model.encode_image(img_tensor).cpu().numpy().flatten()
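Both hunks assume `st.session_state.demo_collection` and `st.session_state.user_collection` were created with cosine distance. Based on the CHROMA_PATH setup and the `"hnsw:space": "cosine"` context line, the initialization above presumably looks like this sketch (the demo collection's name is an assumption; only "user_images" appears in the diff):

```python
import chromadb

# Sketch of the ChromaDB init implied by the context lines;
# the "demo_images" collection name is an assumption.
client = chromadb.PersistentClient(path=CHROMA_PATH)
demo_collection = client.get_or_create_collection(
    name="demo_images", metadata={"hnsw:space": "cosine"}
)
user_collection = client.get_or_create_collection(
    name="user_images", metadata={"hnsw:space": "cosine"}
)
```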
@@ -74,78 +70,81 @@ if not st.session_state.get("demo_images_loaded", False):
             metadatas.append({"path": demo_image_paths[i]})
 
         st.session_state.demo_collection.add(embeddings=embeddings, ids=ids, metadatas=metadatas)
-        st.session_state.demo_images_loaded = True
+        st.session_state.demo_images = demo_images
+        dataset_loaded = True
+        dataset_name = "demo"
 
-# ----- UI -----
-st.title("🔎 CLIP Image Search (Text & Image)")
-mode = st.radio("Choose dataset to search in:", ("Demo Images", "My Uploaded Images"))
-query_type = st.radio("Query type:", ("Image", "Text"))
+    st.success("Demo images loaded!")
 
-# ----- Upload User Images -----
-if mode == "My Uploaded Images":
-    uploaded = st.file_uploader("Upload your images", type=['jpg', 'jpeg', 'png'], accept_multiple_files=True)
+# ----- Handle User Uploads -----
+if upload_own:
+    uploaded = st.file_uploader("Upload your images", type=["jpg", "jpeg", "png"], accept_multiple_files=True)
     if uploaded:
-        st.session_state.user_images = []
         st.session_state.user_collection.delete(ids=[
            str(i) for i in range(st.session_state.user_collection.count())
        ])
-
+        user_images = []
        for i, file in enumerate(uploaded):
            img = Image.open(file).convert("RGB")
-            st.session_state.user_images.append(img)
+            user_images.append(img)
 
            img_tensor = st.session_state.preprocess(img).unsqueeze(0).to(st.session_state.device)
            with torch.no_grad():
                embedding = st.session_state.model.encode_image(img_tensor).cpu().numpy().flatten()
-
            st.session_state.user_collection.add(
-                embeddings=[embedding],
-                ids=[str(i)],
-                metadatas=[{"index": i}]
+                embeddings=[embedding], ids=[str(i)], metadatas=[{"index": i}]
            )
 
-        st.success(f"{len(uploaded)} images uploaded.")
-
-# ----- Perform Query -----
-query_embedding = None
-if query_type == "Image":
-    img_file = st.file_uploader("Upload query image", type=["jpg", "jpeg", "png"])
-    if img_file:
-        img = Image.open(img_file).convert("RGB")
-        st.image(img, caption="Query Image", width=200)
-        img_tensor = st.session_state.preprocess(img).unsqueeze(0).to(st.session_state.device)
-        with torch.no_grad():
-            query_embedding = st.session_state.model.encode_image(img_tensor).cpu().numpy().flatten()
-elif query_type == "Text":
-    text_query = st.text_input("Enter search text:")
-    if text_query:
-        tokens = clip.tokenize([text_query]).to(st.session_state.device)
-        with torch.no_grad():
-            query_embedding = st.session_state.model.encode_text(tokens).cpu().numpy().flatten()
-
-# ----- Run Search -----
-if query_embedding is not None:
-    if mode == "Demo Images":
-        collection = st.session_state.demo_collection
-        images = st.session_state.demo_images
-    else:
-        collection = st.session_state.user_collection
-        images = st.session_state.user_images
-
-    if collection.count() > 0:
-        results = collection.query(
-            query_embeddings=[query_embedding],
-            n_results=min(5, collection.count())
-        )
-        ids = results["ids"][0]
-        distances = results["distances"][0]
-        similarities = [1 - d for d in distances]
-
-        st.subheader("Top Matches")
-        cols = st.columns(5)
-        for i, (img_id, sim) in enumerate(zip(ids, similarities)):
-            with cols[i]:
-                idx = int(img_id)
-                st.image(images[idx], caption=f"Sim: {sim:.3f}", width=150)
-    else:
-        st.warning("No images found in collection.")
+        st.session_state.user_images = user_images
+        st.success(f"{len(user_images)} images uploaded.")
+        dataset_loaded = True
+        dataset_name = "user"
+
+# ----- Search UI -----
+if dataset_loaded:
+    st.subheader("Search Section")
+    query_type = st.radio("Search by:", ("Text", "Image"))
+
+    query_embedding = None
+    if query_type == "Text":
+        text_query = st.text_input("Enter search text:")
+        if text_query:
+            tokens = clip.tokenize([text_query]).to(st.session_state.device)
+            with torch.no_grad():
+                query_embedding = st.session_state.model.encode_text(tokens).cpu().numpy().flatten()
+    else:
+        img_file = st.file_uploader("Upload query image", type=["jpg", "jpeg", "png"])
+        if img_file:
+            query_img = Image.open(img_file).convert("RGB")
+            st.image(query_img, caption="Query Image", width=200)
+            img_tensor = st.session_state.preprocess(query_img).unsqueeze(0).to(st.session_state.device)
+            with torch.no_grad():
+                query_embedding = st.session_state.model.encode_image(img_tensor).cpu().numpy().flatten()
+
+    # ----- Perform Search -----
+    if query_embedding is not None:
+        if dataset_name == "demo":
+            collection = st.session_state.demo_collection
+            images = st.session_state.demo_images
+        else:
+            collection = st.session_state.user_collection
+            images = st.session_state.user_images
+
+        if collection.count() > 0:
+            results = collection.query(
+                query_embeddings=[query_embedding],
+                n_results=min(5, collection.count())
+            )
+            ids = results["ids"][0]
+            distances = results["distances"][0]
+            similarities = [1 - d for d in distances]
+
+            st.subheader("Top Matches")
+            cols = st.columns(len(ids))
+            for i, (img_id, sim) in enumerate(zip(ids, similarities)):
+                with cols[i]:
+                    st.image(images[int(img_id)], caption=f"Sim: {sim:.3f}", width=150)
+        else:
+            st.warning("No images in the collection.")
+else:
+    st.info("Please click on one of the options above to load a dataset.")
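The `similarities = [1 - d for d in distances]` line works because a collection created with `metadata={"hnsw:space": "cosine"}` returns cosine *distance*, defined as 1 - cosine_similarity; subtracting from 1 therefore recovers the similarity. Cosine similarity is scale-invariant, so the unnormalized CLIP embeddings added above are fine. A quick self-contained check:

```python
import numpy as np

def cosine_distance(a, b):
    # Chroma's cosine space: distance = 1 - cos_sim(a, b)
    return 1.0 - np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))

v = np.array([1.0, 0.0])
w = np.array([1.0, 1.0])
d = cosine_distance(v, w)
print(f"distance={d:.4f}, similarity={1 - d:.4f}")  # similarity ~ 0.7071 (cos 45 deg)
```

To try the commit locally: `streamlit run src/streamlit_app.py`.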