NEXAS committed on
Commit 5969029 · verified · 1 Parent(s): 3d041f0

Update src/streamlit_app.py

Files changed (1)
  1. src/streamlit_app.py +99 -206
src/streamlit_app.py CHANGED
@@ -1,210 +1,103 @@
- import os
- import uuid
- import fitz
- import tempfile
  import streamlit as st
  from PIL import Image
  import numpy as np
- from skimage.io import imsave
- from torchvision.datasets import CIFAR10
- import torchvision.transforms as T
- import logging
-
- # Set up logging
- logging.basicConfig(level=logging.DEBUG)
- logger = logging.getLogger(__name__)
-
- # Set HuggingFace cache directory
- HF_CACHE = os.path.join(tempfile.gettempdir(), "hf_cache")
- os.makedirs(HF_CACHE, exist_ok=True)
- os.environ["XDG_CACHE_HOME"] = HF_CACHE
- os.environ["HF_HOME"] = HF_CACHE
- # Add HuggingFace token if needed
- # os.environ["HF_TOKEN"] = "your-huggingface-api-token"
-
- from chromadb import PersistentClient
- from chromadb.utils.data_loaders import ImageLoader
- from chromadb.utils.embedding_functions import OpenCLIPEmbeddingFunction
-
- # Paths
- TEMP_DIR = tempfile.gettempdir()
- IMAGES_DIR = os.path.join(TEMP_DIR, "extracted_images")
- DB_PATH = os.path.join(TEMP_DIR, "image_vdb")
- os.makedirs(IMAGES_DIR, exist_ok=True)
- os.makedirs(DB_PATH, exist_ok=True)
-
- # Init ChromaDB collection
- @st.cache_resource
- def get_chroma_collection():
-     try:
-         chroma_client = PersistentClient(path=DB_PATH)
-         image_loader = ImageLoader()
-         embedding_fn = OpenCLIPEmbeddingFunction()
-         collection = chroma_client.get_or_create_collection(
-             name="image", embedding_function=embedding_fn, data_loader=image_loader
-         )
-         return collection
-     except Exception as e:
-         logger.error(f"Error initializing ChromaDB: {e}")
-         st.error(f"Failed to initialize ChromaDB: {e}")
-         return None
-
- image_collection = get_chroma_collection()
- if image_collection is None:
-     st.stop()
-
- # --- Extract images from PDF ---
- def extract_images_from_pdf(pdf_bytes):
-     try:
-         pdf = fitz.open(stream=pdf_bytes, filetype="pdf")
-         saved = []
-         for i in range(len(pdf)):
-             for img in pdf.load_page(i).get_images(full=True):
-                 base = pdf.extract_image(img[0])
-                 ext = base["ext"]
-                 path = os.path.join(IMAGES_DIR, f"pdf_p{i+1}_img{img[0]}.{ext}")
-                 with open(path, "wb") as f:
-                     f.write(base["image"])
-                 saved.append(path)
-         return saved
-     except Exception as e:
-         logger.error(f"Error extracting images from PDF: {e}")
-         st.error(f"Failed to extract images: {e}")
-         return []
-
- # --- Index images ---
- def index_images(paths):
-     try:
-         ids, uris = [], []
-         for path in sorted(paths):
-             if path.lower().endswith((".jpg", ".jpeg", ".png", ".bmp", ".tiff", ".webp")):
-                 ids.append(str(uuid.uuid4()))
-                 uris.append(path)
-         if ids:
-             image_collection.add(ids=ids, uris=uris)
-     except Exception as e:
-         logger.error(f"Error indexing images: {e}")
-         st.error(f"Failed to index images: {e}")
-
- # --- Image-to-Image search ---
- def query_similar_images(image_file, top_k=5):
-     try:
-         with tempfile.NamedTemporaryFile(delete=False, suffix=".jpg") as tmp:
-             tmp.write(image_file.read())
-             tmp.flush()
-         res = image_collection.query(query_uris=[tmp.name], n_results=top_k)
-         os.remove(tmp.name)
-         if not res or 'uris' not in res or not res['uris'] or not res['uris'][0]:
-             return []
-         return res['uris'][0]
-     except Exception as e:
-         logger.error(f"Error in image-to-image search: {e}")
-         st.error(f"Failed to perform image search: {e}")
-         return []
-
- # --- Text-to-Image search ---
- def search_images_by_text(text, top_k=5):
-     try:
-         res = image_collection.query(query_texts=[text], n_results=top_k)
-         if not res or 'uris' not in res or not res['uris'] or not res['uris'][0]:
-             return []
-         return res['uris'][0]
-     except Exception as e:
-         logger.error(f"Error in text-to-image search: {e}")
-         st.error(f"Failed to perform text search: {e}")
-         return []
-
- # --- Load CIFAR-10 Demo Dataset (500 images) ---
- @st.cache_resource
- def load_demo_cifar10(n=500):
-     try:
-         dataset = CIFAR10(root=TEMP_DIR, download=True, train=True)
-         transform = T.ToPILImage()
-         saved = []
-
-         progress_bar = st.progress(0)
-         for i in range(min(n, len(dataset))):
-             img, label = dataset[i]
-             if not isinstance(img, Image.Image):
-                 img = transform(img)
-             path = os.path.join(IMAGES_DIR, f"cifar10_{i}_{label}.png")
-             img.save(path)
-             saved.append(path)
-             if i % 10 == 0 or i == n - 1:
-                 progress_bar.progress((i + 1) / n)
-         return saved
-     except Exception as e:
-         logger.error(f"Error loading CIFAR-10 dataset: {e}")
-         st.error(f"Failed to load CIFAR-10 dataset: {e}")
-         return []
-
- # === UI START ===
- st.title("🔍 Semantic Image Search App")
-
- # Step 1: Load data
- choice = st.radio("Select Image Source", ["Upload PDF", "Upload Images", "Load CIFAR‑10 Demo"], horizontal=True)
-
- if choice == "Upload PDF":
-     pdf = st.file_uploader("📤 Upload PDF file", type=["pdf"])
-     if pdf:
-         with st.spinner("Extracting images from PDF..."):
-             imgs = extract_images_from_pdf(pdf.read())
-         if imgs:
-             index_images(imgs)
-             st.success(f"✅ Indexed {len(imgs)} images from PDF.")
-             st.image(imgs, width=120)
-         else:
-             st.warning("No images extracted from PDF.")
-
- elif choice == "Upload Images":
-     imgs = st.file_uploader("📤 Upload image files", type=["jpg", "jpeg", "png", "bmp", "tiff", "webp"], accept_multiple_files=True)
-     if imgs:
-         with st.spinner("Indexing uploaded images..."):
-             paths = []
-             for item in imgs:
-                 p = os.path.join(IMAGES_DIR, item.name)
-                 with open(p, "wb") as f:
-                     f.write(item.read())
-                 paths.append(p)
-             index_images(paths)
-         st.success(f"✅ {len(paths)} images indexed.")
-         st.image(paths, width=120)
-
- elif choice == "Load CIFAR‑10 Demo":
-     if st.button("🔄 Load 500 CIFAR‑10 Images"):
-         with st.spinner("Loading CIFAR‑10 demo dataset..."):
-             paths = load_demo_cifar10(500)
-         if paths:
-             index_images(paths)
-             st.success("✅ 500 demo images loaded and indexed.")
-             st.image(paths[:20], width=100)
-         else:
-             st.warning("Failed to load CIFAR-10 images.")
-
- # Step 2: Search
- st.divider()
- st.subheader("🖼️ Image-to-Image Search")
- q = st.file_uploader("📷 Upload a query image", type=["jpg", "jpeg", "png", "bmp", "tiff", "webp"])
- if q:
-     st.image(q, caption="Query Image", width=200)
-     with st.spinner("Finding similar images..."):
-         results = query_similar_images(q, top_k=5)
-     if not results:
-         st.warning("No similar images found.")
-     else:
-         st.subheader("🔁 Top Matches:")
-         for u in results:
-             st.image(u, width=150)

- st.divider()
- st.subheader("📝 Text-to-Image Search")
- txt = st.text_input("Describe what you’re looking for (e.g., 'a beach', 'a cat', 'a red truck'):")
- if txt:
-     with st.spinner("Finding images by semantic similarity..."):
-         results = search_images_by_text(txt, top_k=5)
-     if not results:
-         st.warning("No semantic matches found.")
-     else:
-         st.subheader("🔍 Semantic Matches:")
-         for u in results:
-             st.image(u, width=150)
  import streamlit as st
+ import torch
+ import clip
  from PIL import Image
+ import os
  import numpy as np

+ # Initialize session state
+ if 'model' not in st.session_state:
+     device = "cuda" if torch.cuda.is_available() else "cpu"
+     model, preprocess = clip.load("ViT-B/32", device=device)
+     st.session_state.model = model
+     st.session_state.preprocess = preprocess
+     st.session_state.device = device
+     st.session_state.demo_images = []
+     st.session_state.demo_encodings = []
+     st.session_state.demo_image_paths = []
+     st.session_state.user_images = []
+     st.session_state.user_encodings = []
+
+ # Load demo images
+ if not st.session_state.demo_images:
+     demo_folder = "demo_images"
+     if os.path.exists(demo_folder):
+         demo_image_paths = [os.path.join(demo_folder, f) for f in os.listdir(demo_folder) if f.endswith(('.png', '.jpg', '.jpeg'))]
+         if len(demo_image_paths) > 0:
+             st.session_state.demo_image_paths = demo_image_paths
+             st.session_state.demo_images = [Image.open(path) for path in demo_image_paths]
+             demo_preprocessed = [st.session_state.preprocess(img).unsqueeze(0).to(st.session_state.device) for img in st.session_state.demo_images]
+             with torch.no_grad():
+                 demo_encodings = [st.session_state.model.encode_image(img) for img in demo_preprocessed]
+             st.session_state.demo_encodings = torch.cat(demo_encodings, dim=0)
+         else:
+             st.warning("No images found in 'demo_images' folder. Demo mode will be limited.")
+
+ # Streamlit UI
+ st.title("Image Search with CLIP")
+
+ # Mode selection
+ mode = st.radio("Select mode", ("Search in Demo Images", "Search in My Images"))
+
+ # User images upload
+ if mode == "Search in My Images":
+     st.subheader("Upload Your Images")
+     uploaded_files = st.file_uploader("Choose images", type=['png', 'jpg', 'jpeg'], accept_multiple_files=True)
+
+     if uploaded_files:
+         # Clear previous user images to avoid duplicates
+         st.session_state.user_images = []
+         st.session_state.user_encodings = []
+
+         for uploaded_file in uploaded_files:
+             img = Image.open(uploaded_file)
+             st.session_state.user_images.append(img)
+             img_pre = st.session_state.preprocess(img).unsqueeze(0).to(st.session_state.device)
+             with torch.no_grad():
+                 encoding = st.session_state.model.encode_image(img_pre)
+             st.session_state.user_encodings.append(encoding)
+
+         if st.session_state.user_encodings:
+             st.session_state.user_encodings = torch.cat(st.session_state.user_encodings, dim=0)
+             st.success(f"Uploaded {len(st.session_state.user_images)} images successfully.")
+     else:
+         st.warning("No images uploaded yet.")
+
+ # Query image upload
+ st.subheader("Upload Query Image")
+ query_file = st.file_uploader("Choose a query image", type=['png', 'jpg', 'jpeg'])
+
+ if query_file is not None:
+     query_img = Image.open(query_file)
+     st.image(query_img, caption="Query Image", width=200)
+     query_pre = st.session_state.preprocess(query_img).unsqueeze(0).to(st.session_state.device)
+     with torch.no_grad():
+         query_encoding = st.session_state.model.encode_image(query_pre)
+
+     if mode == "Search in Demo Images":
+         if st.session_state.demo_encodings is not None and len(st.session_state.demo_encodings) > 0:
+             similarities = (st.session_state.demo_encodings @ query_encoding.T).squeeze()
+             top_indices = torch.topk(similarities, min(5, len(similarities))).indices.cpu().numpy()
+
+             st.subheader("Top 5 Similar Images")
+             cols = st.columns(5)
+             for i, idx in enumerate(top_indices):
+                 with cols[i]:
+                     st.image(st.session_state.demo_images[idx], caption=f"Similarity: {similarities[idx]:.4f}", width=150)
+         else:
+             st.error("No demo images available. Please check the 'demo_images' folder.")
+
+     elif mode == "Search in My Images":
+         if st.session_state.user_encodings is not None and len(st.session_state.user_encodings) > 0:
+             similarities = (st.session_state.user_encodings @ query_encoding.T).squeeze()
+             top_indices = torch.topk(similarities, min(5, len(similarities))).indices.cpu().numpy()
+
+             st.subheader("Top 5 Similar Images")
+             cols = st.columns(5)
+             for i, idx in enumerate(top_indices):
+                 with cols[i]:
+                     st.image(st.session_state.user_images[idx], caption=f"Similarity: {similarities[idx]:.4f}", width=150)
+         else:
+             st.error("No user images uploaded yet. Please upload images first.")
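
A note on the similarity step in the new version: it ranks matches by a raw dot product between unnormalized CLIP embeddings, so scores are unbounded and can favor images whose features simply have larger norms. CLIP retrieval is conventionally done on L2-normalized features, i.e. cosine similarity. Below is a minimal sketch of that variant, assuming the same ViT-B/32 checkpoint; the helper names (encode_images, encode_texts, top_k_similar) are illustrative and not part of this commit:

    import torch
    import clip
    from PIL import Image

    # Sketch only, not the committed code: cosine-similarity retrieval with CLIP.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model, preprocess = clip.load("ViT-B/32", device=device)

    def encode_images(paths):
        # Batch-encode images, then L2-normalize so dot products become cosine similarities.
        batch = torch.cat([preprocess(Image.open(p)).unsqueeze(0) for p in paths]).to(device)
        with torch.no_grad():
            feats = model.encode_image(batch)
        return feats / feats.norm(dim=-1, keepdim=True)

    def encode_texts(queries):
        # The same idea for text queries, which would restore the text-to-image
        # search that the removed ChromaDB version offered.
        tokens = clip.tokenize(queries).to(device)
        with torch.no_grad():
            feats = model.encode_text(tokens)
        return feats / feats.norm(dim=-1, keepdim=True)

    def top_k_similar(query_feats, index_feats, k=5):
        # query_feats: (1, d), index_feats: (N, d), both normalized.
        sims = (index_feats @ query_feats.T).squeeze(1)  # cosine similarity per image
        k = min(k, sims.numel())
        scores, idx = torch.topk(sims, k)
        return idx.tolist(), scores.tolist()

Usage would look like idx, scores = top_k_similar(encode_texts(["a red truck"]), encode_images(paths)); with normalized features the scores fall in [-1, 1], so a fixed relevance cutoff also becomes meaningful.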