NEXAS committed on
Commit
60c342d
·
verified ·
1 Parent(s): 43cc2f2

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +97 -121
src/streamlit_app.py CHANGED
@@ -1,28 +1,25 @@
1
  import os
2
- import tempfile
3
-
4
- # Set cache directory to temp or app folder
5
- cache_dir = os.path.join(tempfile.gettempdir(), "hf_cache")
6
- os.makedirs(cache_dir, exist_ok=True)
7
-
8
- os.environ["XDG_CACHE_HOME"] = cache_dir
9
- os.environ["HF_HOME"] = cache_dir
10
-
11
- # Now import OpenCLIPEmbeddingFunction
12
- from chromadb.utils.embedding_functions import OpenCLIPEmbeddingFunction
13
-
14
  import fitz
15
  import tempfile
16
  import streamlit as st
17
  from PIL import Image
 
 
 
 
 
 
 
 
 
 
 
18
  from chromadb import PersistentClient
19
  from chromadb.utils.data_loaders import ImageLoader
20
  from chromadb.utils.embedding_functions import OpenCLIPEmbeddingFunction
21
- from skimage import data as skdata
22
- from skimage.io import imsave
23
- import uuid
24
 
25
- # Use safe temp directories for Streamlit or restricted environments
26
  TEMP_DIR = tempfile.gettempdir()
27
  IMAGES_DIR = os.path.join(TEMP_DIR, "extracted_images")
28
  DB_PATH = os.path.join(TEMP_DIR, "image_vdb")
@@ -40,122 +37,101 @@ def get_chroma_collection():
40
 
41
  image_collection = get_chroma_collection()
42
 
43
- # === Image Extraction ===
44
  def extract_images_from_pdf(pdf_bytes):
45
  pdf = fitz.open(stream=pdf_bytes, filetype="pdf")
46
- saved_images = []
47
-
48
- for page_num in range(len(pdf)):
49
- page = pdf.load_page(page_num)
50
- images = page.get_images(full=True)
51
-
52
- for img_idx, img in enumerate(images):
53
- xref = img[0]
54
- base_image = pdf.extract_image(xref)
55
- img_bytes = base_image["image"]
56
- ext = base_image["ext"]
57
- filename = f"page_{page_num+1}_img_{img_idx+1}.{ext}"
58
- path = os.path.join(IMAGES_DIR, filename)
59
-
60
- with open(path, "wb") as f:
61
- f.write(img_bytes)
62
-
63
- saved_images.append(path)
64
-
65
- return saved_images
66
-
67
- # === Indexing ===
68
- def index_images(image_paths):
69
- ids = []
70
- uris = []
71
- for path in sorted(image_paths):
72
- if path.lower().endswith((".png", ".jpeg", ".jpg")):
73
  ids.append(str(uuid.uuid4()))
74
  uris.append(path)
75
-
76
  if ids:
77
  image_collection.add(ids=ids, uris=uris)
78
 
79
- # === Querying ===
80
  def query_similar_images(image_file, top_k=5):
81
  with tempfile.NamedTemporaryFile(delete=False, suffix=".jpg") as tmp:
82
  tmp.write(image_file.read())
83
- tmp_path = tmp.name
84
-
85
- try:
86
- results = image_collection.query(query_uris=[tmp_path], n_results=top_k)
87
- return results['uris'][0]
88
- finally:
89
- os.remove(tmp_path)
90
-
91
- # === Demo images ===
92
- def load_skimage_demo_images():
93
- demo_images = {
94
- "astronaut": skdata.astronaut(),
95
- "coffee": skdata.coffee(),
96
- "camera": skdata.camera(),
97
- "chelsea": skdata.chelsea(),
98
- "rocket": skdata.rocket()
99
- }
100
- saved_paths = []
101
-
102
- for name, img in demo_images.items():
103
- path = os.path.join(IMAGES_DIR, f"{name}.png")
104
- imsave(path, img)
105
- saved_paths.append(path)
106
-
107
- return saved_paths
108
-
109
- # === Streamlit UI ===
110
- st.title("🔍 Image Similarity Search from PDF or Custom Dataset")
111
-
112
- source = st.radio(
113
- "Select Image Source",
114
- ["Upload PDF", "Upload Images", "Load Demo Dataset"],
115
- horizontal=True
116
- )
117
-
118
- if source == "Upload PDF":
119
- uploaded_pdf = st.file_uploader("📤 Upload PDF", type=["pdf"])
120
- if uploaded_pdf:
121
- with st.spinner("Extracting images..."):
122
- images = extract_images_from_pdf(uploaded_pdf.read())
123
- index_images(images)
124
- st.success(f"{len(images)} images extracted and indexed.")
125
- st.image(images, width=150)
126
-
127
- elif source == "Upload Images":
128
- uploaded_imgs = st.file_uploader(
129
- "📤 Upload one or more images", type=["jpg", "jpeg", "png"], accept_multiple_files=True
130
- )
131
- if uploaded_imgs:
132
- saved_paths = []
133
- for img in uploaded_imgs:
134
- img_path = os.path.join(IMAGES_DIR, img.name)
135
- with open(img_path, "wb") as f:
136
- f.write(img.read())
137
- saved_paths.append(img_path)
138
-
139
- index_images(saved_paths)
140
- st.success(f"{len(saved_paths)} images indexed.")
141
- st.image(saved_paths, width=150)
142
-
143
- elif source == "Load Demo Dataset":
144
- if st.button("🔄 Load Demo Images (skimage)"):
145
- demo_paths = load_skimage_demo_images()
146
- index_images(demo_paths)
147
- st.success("Demo images loaded and indexed.")
148
- st.image(demo_paths, width=150)
149
 
150
- st.divider()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
151
 
152
- st.subheader("🔎 Search for Similar Images")
153
- query_img = st.file_uploader("Upload a query image", type=["jpg", "jpeg", "png"])
154
- if query_img:
155
- st.image(query_img, caption="Query Image", width=200)
 
156
  with st.spinner("Searching..."):
157
- matches = query_similar_images(query_img, top_k=5)
 
 
158
 
159
- st.subheader("📊 Top Matches:")
160
- for match in matches:
161
- st.image(match, width=200, caption=os.path.basename(match))
 
 
 
 
 
 
1
  import os
2
+ import uuid
 
 
 
 
 
 
 
 
 
 
 
3
  import fitz
4
  import tempfile
5
  import streamlit as st
6
  from PIL import Image
7
+ import numpy as np
8
+ from skimage.io import imsave
9
+ from torchvision.datasets import CIFAR10
10
+ import torchvision.transforms as T
11
+
12
+ # Setup cache paths
13
+ HF_CACHE = os.path.join(tempfile.gettempdir(), "hf_cache")
14
+ os.makedirs(HF_CACHE, exist_ok=True)
15
+ os.environ["XDG_CACHE_HOME"] = HF_CACHE
16
+ os.environ["HF_HOME"] = HF_CACHE
17
+
18
  from chromadb import PersistentClient
19
  from chromadb.utils.data_loaders import ImageLoader
20
  from chromadb.utils.embedding_functions import OpenCLIPEmbeddingFunction
 
 
 
21
 
22
+ # Directories
23
  TEMP_DIR = tempfile.gettempdir()
24
  IMAGES_DIR = os.path.join(TEMP_DIR, "extracted_images")
25
  DB_PATH = os.path.join(TEMP_DIR, "image_vdb")
 
37
 
38
  image_collection = get_chroma_collection()
39
 
40
+ # PDFs & Uploads —
41
  def extract_images_from_pdf(pdf_bytes):
42
  pdf = fitz.open(stream=pdf_bytes, filetype="pdf")
43
+ saved = []
44
+ for i in range(len(pdf)):
45
+ for img in pdf.load_page(i).get_images(full=True):
46
+ base = pdf.extract_image(img[0])
47
+ ext = base["ext"]
48
+ path = os.path.join(IMAGES_DIR, f"pdf_p{i+1}_img{img[0]}.{ext}")
49
+ with open(path,"wb") as f: f.write(base["image"])
50
+ saved.append(path)
51
+ return saved
52
+
53
+ def index_images(paths):
54
+ ids, uris = [], []
55
+ for path in sorted(paths):
56
+ if path.lower().endswith((".jpg",".jpeg",".png")):
 
 
 
 
 
 
 
 
 
 
 
 
 
57
  ids.append(str(uuid.uuid4()))
58
  uris.append(path)
 
59
  if ids:
60
  image_collection.add(ids=ids, uris=uris)
61
 
62
+ # Queries
63
  def query_similar_images(image_file, top_k=5):
64
  with tempfile.NamedTemporaryFile(delete=False, suffix=".jpg") as tmp:
65
  tmp.write(image_file.read())
66
+ tmp.flush()
67
+ res = image_collection.query(query_uris=[tmp.name], n_results=top_k)
68
+ os.remove(tmp.name)
69
+ return res['uris'][0]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
70
 
71
+ def search_images_by_text(text, top_k=5):
72
+ res = image_collection.query(query_texts=[text], n_results=top_k)
73
+ return res['uris'][0]
74
+
75
+ # — Demo Dataset: CIFAR10 (500 images) —
76
+ @st.cache_resource
77
+ def load_demo_cifar10(n=500):
78
+ dataset = CIFAR10(root=TEMP_DIR, download=True, train=True)
79
+ transform = T.ToPILImage()
80
+ saved = []
81
+ for i in range(min(n, len(dataset))):
82
+ img, label = dataset[i]
83
+ if not isinstance(img, Image.Image):
84
+ img = transform(img)
85
+ path = os.path.join(IMAGES_DIR, f"cifar10_{i}_{label}.png")
86
+ img.save(path)
87
+ saved.append(path)
88
+ return saved
89
+
90
+ # — UI Starts —
91
+ st.title("🔍 Image & Text Similarity Search with 500‑Image Demo DB")
92
+
93
+ choice = st.radio("Select data source", ["Upload PDF", "Upload Images", "Load CIFAR‑10 Demo"], horizontal=True)
94
+
95
+ if choice=="Upload PDF":
96
+ pdf = st.file_uploader("📤 Upload PDF", type=["pdf"])
97
+ if pdf:
98
+ with st.spinner("Extracting..."):
99
+ imgs = extract_images_from_pdf(pdf.read()); index_images(imgs)
100
+ st.success(f"{len(imgs)} images indexed from PDF")
101
+ st.image(imgs, width=120)
102
+
103
+ elif choice=="Upload Images":
104
+ imgs = st.file_uploader("📤 Upload images", accept_multiple_files=True, type=["jpg","jpeg","png"])
105
+ if imgs:
106
+ paths=[]
107
+ for item in imgs:
108
+ p=os.path.join(IMAGES_DIR, item.name)
109
+ with open(p,"wb") as f: f.write(item.read()); paths.append(p)
110
+ index_images(paths)
111
+ st.success(f"{len(paths)} images uploaded & indexed")
112
+ st.image(paths, width=120)
113
+
114
+ elif choice=="Load CIFAR‑10 Demo":
115
+ if st.button("🔄 Load 500 CIFAR‑10 Images"):
116
+ paths=load_demo_cifar10(500); index_images(paths)
117
+ st.success("500 CIFAR‑10 demo images loaded and indexed")
118
+ st.image(paths[:20], width=100)
119
 
120
+ st.divider()
121
+ st.subheader("🔎 Image-Based Search")
122
+ q = st.file_uploader("Upload a query image", type=["jpg","jpeg","png"])
123
+ if q:
124
+ st.image(q, caption="Query");
125
  with st.spinner("Searching..."):
126
+ out = query_similar_images(q, top_k=5)
127
+ st.subheader("Top Image Matches")
128
+ for u in out: st.image(u, width=150)
129
 
130
+ st.divider()
131
+ st.subheader("📝 Text-to-Image Semantic Search")
132
+ txt = st.text_input("Enter description (e.g. 'a beach'):")
133
+ if txt:
134
+ with st.spinner("Searching..."):
135
+ out = search_images_by_text(txt, top_k=5)
136
+ st.subheader("Top Semantic Matches")
137
+ for u in out: st.image(u, width=150)