NEXAS commited on
Commit
ee979c8
Β·
verified Β·
1 Parent(s): 228d5b0

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +29 -20
src/streamlit_app.py CHANGED
@@ -10,20 +10,25 @@ from skimage import data as skdata
10
  from skimage.io import imsave
11
  import uuid
12
 
13
- # Paths
14
- DB_PATH = './data/image_vdb'
15
- IMAGES_DIR = './data/extracted_images'
 
16
  os.makedirs(IMAGES_DIR, exist_ok=True)
17
 
18
- # Init ChromaDB
19
- chroma_client = PersistentClient(path=DB_PATH)
20
- image_loader = ImageLoader()
21
- embedding_fn = OpenCLIPEmbeddingFunction()
22
- image_collection = chroma_client.get_or_create_collection(
23
- name="image", embedding_function=embedding_fn, data_loader=image_loader
24
- )
 
 
 
 
25
 
26
- # === Image Handling ===
27
  def extract_images_from_pdf(pdf_bytes):
28
  pdf = fitz.open(stream=pdf_bytes, filetype="pdf")
29
  saved_images = []
@@ -47,26 +52,31 @@ def extract_images_from_pdf(pdf_bytes):
47
 
48
  return saved_images
49
 
 
50
  def index_images(image_paths):
51
  ids = []
52
  uris = []
53
- for i, path in enumerate(sorted(image_paths)):
54
- if path.endswith((".png", ".jpeg", ".jpg")):
55
  ids.append(str(uuid.uuid4()))
56
  uris.append(path)
57
 
58
  if ids:
59
  image_collection.add(ids=ids, uris=uris)
60
 
 
61
  def query_similar_images(image_file, top_k=5):
62
  with tempfile.NamedTemporaryFile(delete=False, suffix=".jpg") as tmp:
63
  tmp.write(image_file.read())
64
  tmp_path = tmp.name
65
 
66
- results = image_collection.query(query_uris=[tmp_path], n_results=top_k)
67
- os.remove(tmp_path)
68
- return results['uris'][0]
 
 
69
 
 
70
  def load_skimage_demo_images():
71
  demo_images = {
72
  "astronaut": skdata.astronaut(),
@@ -87,7 +97,6 @@ def load_skimage_demo_images():
87
  # === Streamlit UI ===
88
  st.title("πŸ” Image Similarity Search from PDF or Custom Dataset")
89
 
90
- # Source Selector
91
  source = st.radio(
92
  "Select Image Source",
93
  ["Upload PDF", "Upload Images", "Load Demo Dataset"],
@@ -104,7 +113,9 @@ if source == "Upload PDF":
104
  st.image(images, width=150)
105
 
106
  elif source == "Upload Images":
107
- uploaded_imgs = st.file_uploader("πŸ“€ Upload one or more images", type=["jpg", "jpeg", "png"], accept_multiple_files=True)
 
 
108
  if uploaded_imgs:
109
  saved_paths = []
110
  for img in uploaded_imgs:
@@ -124,10 +135,8 @@ elif source == "Load Demo Dataset":
124
  st.success("Demo images loaded and indexed.")
125
  st.image(demo_paths, width=150)
126
 
127
- # Divider
128
  st.divider()
129
 
130
- # Query Interface
131
  st.subheader("πŸ”Ž Search for Similar Images")
132
  query_img = st.file_uploader("Upload a query image", type=["jpg", "jpeg", "png"])
133
  if query_img:
 
10
  from skimage.io import imsave
11
  import uuid
12
 
13
+ # Use safe temp directories for Streamlit or restricted environments
14
+ TEMP_DIR = tempfile.gettempdir()
15
+ IMAGES_DIR = os.path.join(TEMP_DIR, "extracted_images")
16
+ DB_PATH = os.path.join(TEMP_DIR, "image_vdb")
17
  os.makedirs(IMAGES_DIR, exist_ok=True)
18
 
19
+ @st.cache_resource
20
+ def get_chroma_collection():
21
+ chroma_client = PersistentClient(path=DB_PATH)
22
+ image_loader = ImageLoader()
23
+ embedding_fn = OpenCLIPEmbeddingFunction()
24
+ collection = chroma_client.get_or_create_collection(
25
+ name="image", embedding_function=embedding_fn, data_loader=image_loader
26
+ )
27
+ return collection
28
+
29
+ image_collection = get_chroma_collection()
30
 
31
+ # === Image Extraction ===
32
  def extract_images_from_pdf(pdf_bytes):
33
  pdf = fitz.open(stream=pdf_bytes, filetype="pdf")
34
  saved_images = []
 
52
 
53
  return saved_images
54
 
55
+ # === Indexing ===
56
  def index_images(image_paths):
57
  ids = []
58
  uris = []
59
+ for path in sorted(image_paths):
60
+ if path.lower().endswith((".png", ".jpeg", ".jpg")):
61
  ids.append(str(uuid.uuid4()))
62
  uris.append(path)
63
 
64
  if ids:
65
  image_collection.add(ids=ids, uris=uris)
66
 
67
+ # === Querying ===
68
  def query_similar_images(image_file, top_k=5):
69
  with tempfile.NamedTemporaryFile(delete=False, suffix=".jpg") as tmp:
70
  tmp.write(image_file.read())
71
  tmp_path = tmp.name
72
 
73
+ try:
74
+ results = image_collection.query(query_uris=[tmp_path], n_results=top_k)
75
+ return results['uris'][0]
76
+ finally:
77
+ os.remove(tmp_path)
78
 
79
+ # === Demo images ===
80
  def load_skimage_demo_images():
81
  demo_images = {
82
  "astronaut": skdata.astronaut(),
 
97
  # === Streamlit UI ===
98
  st.title("πŸ” Image Similarity Search from PDF or Custom Dataset")
99
 
 
100
  source = st.radio(
101
  "Select Image Source",
102
  ["Upload PDF", "Upload Images", "Load Demo Dataset"],
 
113
  st.image(images, width=150)
114
 
115
  elif source == "Upload Images":
116
+ uploaded_imgs = st.file_uploader(
117
+ "πŸ“€ Upload one or more images", type=["jpg", "jpeg", "png"], accept_multiple_files=True
118
+ )
119
  if uploaded_imgs:
120
  saved_paths = []
121
  for img in uploaded_imgs:
 
135
  st.success("Demo images loaded and indexed.")
136
  st.image(demo_paths, width=150)
137
 
 
138
  st.divider()
139
 
 
140
  st.subheader("πŸ”Ž Search for Similar Images")
141
  query_img = st.file_uploader("Upload a query image", type=["jpg", "jpeg", "png"])
142
  if query_img: