NEXAS committed on
Commit
d5eb2b5
·
verified ·
1 Parent(s): 6401b00

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +77 -28
src/streamlit_app.py CHANGED
@@ -1,26 +1,29 @@
1
  import os
2
- import fitz # PyMuPDF
3
- import chromadb
4
  import tempfile
5
  import streamlit as st
6
  from PIL import Image
 
7
  from chromadb.utils.data_loaders import ImageLoader
8
  from chromadb.utils.embedding_functions import OpenCLIPEmbeddingFunction
 
 
 
9
 
10
  # Paths
11
  DB_PATH = './data/image_vdb'
12
  IMAGES_DIR = './data/extracted_images'
13
  os.makedirs(IMAGES_DIR, exist_ok=True)
14
 
15
- # Init Chroma
16
- chroma_client = chromadb.PersistentClient(path=DB_PATH)
17
  image_loader = ImageLoader()
18
  embedding_fn = OpenCLIPEmbeddingFunction()
19
  image_collection = chroma_client.get_or_create_collection(
20
  name="image", embedding_function=embedding_fn, data_loader=image_loader
21
  )
22
 
23
- # Utilities
24
  def extract_images_from_pdf(pdf_bytes):
25
  pdf = fitz.open(stream=pdf_bytes, filetype="pdf")
26
  saved_images = []
@@ -44,13 +47,12 @@ def extract_images_from_pdf(pdf_bytes):
44
 
45
  return saved_images
46
 
47
- def index_images_in_chroma(image_paths):
48
  ids = []
49
  uris = []
50
-
51
  for i, path in enumerate(sorted(image_paths)):
52
  if path.endswith((".png", ".jpeg", ".jpg")):
53
- ids.append(f"img_{len(image_collection.get()['ids']) + i}")
54
  uris.append(path)
55
 
56
  if ids:
@@ -65,27 +67,74 @@ def query_similar_images(image_file, top_k=5):
65
  os.remove(tmp_path)
66
  return results['uris'][0]
67
 
68
- # Streamlit UI
69
- st.title("πŸ” Image Search from PDF (HR Tool Demo)")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
70
 
71
- with st.expander("πŸ“€ Step 1: Upload PDF to Extract Images"):
72
- uploaded_pdf = st.file_uploader("Upload a PDF file", type=["pdf"])
73
- if uploaded_pdf is not None:
74
  with st.spinner("Extracting images..."):
75
- saved_images = extract_images_from_pdf(uploaded_pdf.read())
76
- index_images_in_chroma(saved_images)
77
- st.success(f"Extracted and indexed {len(saved_images)} images.")
78
- st.image(saved_images, caption="Extracted images", width=150)
79
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
80
  st.divider()
81
 
82
- with st.expander("πŸ–ΌοΈ Step 2: Search by Uploading a Query Image"):
83
- query_img = st.file_uploader("Upload a query image", type=["jpg", "jpeg", "png"])
84
- if query_img is not None:
85
- st.image(query_img, caption="Query Image", width=200)
86
- with st.spinner("Searching similar images..."):
87
- results = query_similar_images(query_img, top_k=5)
88
-
89
- st.subheader("πŸ”Ž Top Matches:")
90
- for res_path in results:
91
- st.image(res_path, width=200, caption=os.path.basename(res_path))
 
 
1
  import os
2
+ import fitz
 
3
  import tempfile
4
  import streamlit as st
5
  from PIL import Image
6
+ from chromadb import PersistentClient
7
  from chromadb.utils.data_loaders import ImageLoader
8
  from chromadb.utils.embedding_functions import OpenCLIPEmbeddingFunction
9
+ from skimage import data as skdata
10
+ from skimage.io import imsave
11
+ import uuid
12
 
13
  # Paths
14
  DB_PATH = './data/image_vdb'
15
  IMAGES_DIR = './data/extracted_images'
16
  os.makedirs(IMAGES_DIR, exist_ok=True)
17
 
18
# Init ChromaDB
@st.cache_resource
def _init_image_store():
    """Build the Chroma client, image loader, embedder and collection once.

    Streamlit re-executes this script from the top on every widget
    interaction. Without caching, each rerun would construct a new
    PersistentClient and a new OpenCLIPEmbeddingFunction (presumably the
    expensive part, since it has to set up the CLIP model -- confirm against
    the chromadb implementation). ``st.cache_resource`` keeps a single shared
    instance alive for the lifetime of the server process.

    Returns:
        tuple: ``(client, loader, embedder, collection)`` where ``collection``
        is the "image" collection opened (or created) under ``DB_PATH``.
    """
    client = PersistentClient(path=DB_PATH)
    loader = ImageLoader()
    embedder = OpenCLIPEmbeddingFunction()
    collection = client.get_or_create_collection(
        name="image", embedding_function=embedder, data_loader=loader
    )
    return client, loader, embedder, collection


# Keep the original module-level names so the rest of the file is unaffected.
chroma_client, image_loader, embedding_fn, image_collection = _init_image_store()
25
 
26
+ # === Image Handling ===
27
  def extract_images_from_pdf(pdf_bytes):
28
  pdf = fitz.open(stream=pdf_bytes, filetype="pdf")
29
  saved_images = []
 
47
 
48
  return saved_images
49
 
50
+ def index_images(image_paths):
51
  ids = []
52
  uris = []
 
53
  for i, path in enumerate(sorted(image_paths)):
54
  if path.endswith((".png", ".jpeg", ".jpg")):
55
+ ids.append(str(uuid.uuid4()))
56
  uris.append(path)
57
 
58
  if ids:
 
67
  os.remove(tmp_path)
68
  return results['uris'][0]
69
 
70
def load_skimage_demo_images():
    """Save five scikit-image sample images into IMAGES_DIR as PNG files.

    The samples are astronaut, coffee, camera, chelsea and rocket; each is
    written to ``<IMAGES_DIR>/<name>.png``.

    Returns:
        list[str]: The written file paths, in the order listed above.
    """
    # Load every sample up front, before anything is written to disk.
    samples = [
        ("astronaut", skdata.astronaut()),
        ("coffee", skdata.coffee()),
        ("camera", skdata.camera()),
        ("chelsea", skdata.chelsea()),
        ("rocket", skdata.rocket()),
    ]

    written = []
    for stem, pixels in samples:
        target = os.path.join(IMAGES_DIR, f"{stem}.png")
        imsave(target, pixels)
        written.append(target)
    return written
86
+
87
# === Streamlit UI ===
st.title("🔍 Image Similarity Search from PDF or Custom Dataset")

# Streamlit re-executes this script on every widget interaction, and an
# uploaded file stays in its widget across reruns. index_images() assigns a
# fresh UUID to every path, so re-running it for the same upload would
# presumably add duplicate entries to the collection. Remember what this
# browser session has already indexed and skip it on later reruns.
# NOTE: the guard is per session; a new session can still re-index the same
# file into the persistent store.
if "indexed_uploads" not in st.session_state:
    st.session_state["indexed_uploads"] = {}
indexed_uploads = st.session_state["indexed_uploads"]

# Source Selector
source = st.radio(
    "Select Image Source",
    ["Upload PDF", "Upload Images", "Load Demo Dataset"],
    horizontal=True,
)

if source == "Upload PDF":
    uploaded_pdf = st.file_uploader("📀 Upload PDF", type=["pdf"])
    if uploaded_pdf:
        # Name + size identifies the upload well enough for a rerun guard.
        pdf_key = ("pdf", uploaded_pdf.name, uploaded_pdf.size)
        if pdf_key not in indexed_uploads:
            with st.spinner("Extracting images..."):
                images = extract_images_from_pdf(uploaded_pdf.read())
                index_images(images)
            indexed_uploads[pdf_key] = images
        images = indexed_uploads[pdf_key]
        st.success(f"{len(images)} images extracted and indexed.")
        st.image(images, width=150)

elif source == "Upload Images":
    uploaded_imgs = st.file_uploader(
        "📀 Upload one or more images",
        type=["jpg", "jpeg", "png"],
        accept_multiple_files=True,
    )
    if uploaded_imgs:
        saved_paths = []
        new_paths = []
        for img in uploaded_imgs:
            # The file name comes from the client; drop any directory part so
            # the write cannot land outside IMAGES_DIR.
            safe_name = os.path.basename(img.name)
            img_path = os.path.join(IMAGES_DIR, safe_name)
            img_key = ("image", safe_name, img.size)
            if img_key not in indexed_uploads:
                with open(img_path, "wb") as f:
                    f.write(img.read())
                indexed_uploads[img_key] = img_path
                new_paths.append(img_path)
            saved_paths.append(img_path)

        # Only files not seen earlier in this session are sent for indexing.
        if new_paths:
            index_images(new_paths)
        st.success(f"{len(saved_paths)} images indexed.")
        st.image(saved_paths, width=150)

elif source == "Load Demo Dataset":
    if st.button("🔄 Load Demo Images (skimage)"):
        demo_key = ("demo", "skimage")
        # A second click in the same session reuses the earlier result
        # instead of indexing the same five images again.
        if demo_key not in indexed_uploads:
            demo_paths = load_skimage_demo_images()
            index_images(demo_paths)
            indexed_uploads[demo_key] = demo_paths
        demo_paths = indexed_uploads[demo_key]
        st.success("Demo images loaded and indexed.")
        st.image(demo_paths, width=150)
126
+
127
+ # Divider
128
  st.divider()
129
 
130
# Query Interface
# Upload a single image, look up its nearest neighbours via
# query_similar_images(), and render each returned path with its file name.
st.subheader("🔎 Search for Similar Images")
query_img = st.file_uploader("Upload a query image", type=["jpg", "jpeg", "png"])
if query_img:
    st.image(query_img, caption="Query Image", width=200)
    with st.spinner("Searching..."):
        matches = query_similar_images(query_img, top_k=5)

    st.subheader("📊 Top Matches:")
    if not matches:
        # Nothing came back (e.g. nothing has been indexed yet): say so
        # rather than leaving an empty section under the heading.
        st.info("No similar images found. Index some images first.")
    else:
        for match in matches:
            st.image(match, width=200, caption=os.path.basename(match))