NEXAS commited on
Commit
3d041f0
·
verified ·
1 Parent(s): 85eeb6f

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +102 -60
src/streamlit_app.py CHANGED
@@ -8,12 +8,19 @@ import numpy as np
8
  from skimage.io import imsave
9
  from torchvision.datasets import CIFAR10
10
  import torchvision.transforms as T
 
 
 
 
 
11
 
12
  # Set HuggingFace cache directory
13
  HF_CACHE = os.path.join(tempfile.gettempdir(), "hf_cache")
14
  os.makedirs(HF_CACHE, exist_ok=True)
15
  os.environ["XDG_CACHE_HOME"] = HF_CACHE
16
  os.environ["HF_HOME"] = HF_CACHE
 
 
17
 
18
  from chromadb import PersistentClient
19
  from chromadb.utils.data_loaders import ImageLoader
@@ -24,83 +31,112 @@ TEMP_DIR = tempfile.gettempdir()
24
  IMAGES_DIR = os.path.join(TEMP_DIR, "extracted_images")
25
  DB_PATH = os.path.join(TEMP_DIR, "image_vdb")
26
  os.makedirs(IMAGES_DIR, exist_ok=True)
 
27
 
28
  # Init ChromaDB collection
29
  @st.cache_resource
30
  def get_chroma_collection():
31
- chroma_client = PersistentClient(path=DB_PATH)
32
- image_loader = ImageLoader()
33
- embedding_fn = OpenCLIPEmbeddingFunction()
34
- collection = chroma_client.get_or_create_collection(
35
- name="image", embedding_function=embedding_fn, data_loader=image_loader
36
- )
37
- return collection
 
 
 
 
 
38
 
39
  image_collection = get_chroma_collection()
 
 
40
 
41
  # --- Extract images from PDF ---
42
  def extract_images_from_pdf(pdf_bytes):
43
- pdf = fitz.open(stream=pdf_bytes, filetype="pdf")
44
- saved = []
45
- for i in range(len(pdf)):
46
- for img in pdf.load_page(i).get_images(full=True):
47
- base = pdf.extract_image(img[0])
48
- ext = base["ext"]
49
- path = os.path.join(IMAGES_DIR, f"pdf_p{i+1}_img{img[0]}.{ext}")
50
- with open(path, "wb") as f:
51
- f.write(base["image"])
52
- saved.append(path)
53
- return saved
 
 
 
 
 
54
 
55
  # --- Index images ---
56
  def index_images(paths):
57
- ids, uris = [], []
58
- for path in sorted(paths):
59
- if path.lower().endswith((".jpg", ".jpeg", ".png", ".bmp", ".tiff", ".webp")):
60
- ids.append(str(uuid.uuid4()))
61
- uris.append(path)
62
- if ids:
63
- image_collection.add(ids=ids, uris=uris)
 
 
 
 
64
 
65
  # --- Image-to-Image search ---
66
  def query_similar_images(image_file, top_k=5):
67
- with tempfile.NamedTemporaryFile(delete=False, suffix=".jpg") as tmp:
68
- tmp.write(image_file.read())
69
- tmp.flush()
70
- res = image_collection.query(query_uris=[tmp.name], n_results=top_k)
71
- os.remove(tmp.name)
72
- # Safe check for results
73
- if not res or 'uris' not in res or not res['uris'] or not res['uris'][0]:
 
 
 
 
 
74
  return []
75
- return res['uris'][0]
76
 
77
  # --- Text-to-Image search ---
78
  def search_images_by_text(text, top_k=5):
79
- res = image_collection.query(query_texts=[text], n_results=top_k)
80
- # Safe check for results
81
- if not res or 'uris' not in res or not res['uris'] or not res['uris'][0]:
 
 
 
 
 
82
  return []
83
- return res['uris'][0]
84
 
85
  # --- Load CIFAR-10 Demo Dataset (500 images) ---
86
  @st.cache_resource
87
  def load_demo_cifar10(n=500):
88
- dataset = CIFAR10(root=TEMP_DIR, download=True, train=True)
89
- transform = T.ToPILImage()
90
- saved = []
91
-
92
- progress_bar = st.progress(0)
93
- for i in range(min(n, len(dataset))):
94
- img, label = dataset[i]
95
- if not isinstance(img, Image.Image):
96
- img = transform(img)
97
- path = os.path.join(IMAGES_DIR, f"cifar10_{i}_{label}.png")
98
- img.save(path)
99
- saved.append(path)
100
- if i % 10 == 0 or i == n - 1:
101
- progress_bar.progress((i + 1) / n)
102
-
103
- return saved
 
 
 
 
104
 
105
  # === UI START ===
106
  st.title("🔍 Semantic Image Search App")
@@ -113,9 +149,12 @@ if choice == "Upload PDF":
113
  if pdf:
114
  with st.spinner("Extracting images from PDF..."):
115
  imgs = extract_images_from_pdf(pdf.read())
116
- index_images(imgs)
117
- st.success(f"✅ Indexed {len(imgs)} images from PDF.")
118
- st.image(imgs, width=120)
 
 
 
119
 
120
  elif choice == "Upload Images":
121
  imgs = st.file_uploader("📤 Upload image files", type=["jpg", "jpeg", "png", "bmp", "tiff", "webp"], accept_multiple_files=True)
@@ -135,9 +174,12 @@ elif choice == "Load CIFAR‑10 Demo":
135
  if st.button("🔄 Load 500 CIFAR‑10 Images"):
136
  with st.spinner("Loading CIFAR‑10 demo dataset..."):
137
  paths = load_demo_cifar10(500)
138
- index_images(paths)
139
- st.success("✅ 500 demo images loaded and indexed.")
140
- st.image(paths[:20], width=100)
 
 
 
141
 
142
  # Step 2: Search
143
  st.divider()
@@ -165,4 +207,4 @@ if txt:
165
  else:
166
  st.subheader("🔍 Semantic Matches:")
167
  for u in results:
168
- st.image(u, width=150)
 
8
  from skimage.io import imsave
9
  from torchvision.datasets import CIFAR10
10
  import torchvision.transforms as T
11
+ import logging
12
+
13
+ # Set up logging
14
+ logging.basicConfig(level=logging.DEBUG)
15
+ logger = logging.getLogger(__name__)
16
 
17
  # Set HuggingFace cache directory
18
  HF_CACHE = os.path.join(tempfile.gettempdir(), "hf_cache")
19
  os.makedirs(HF_CACHE, exist_ok=True)
20
  os.environ["XDG_CACHE_HOME"] = HF_CACHE
21
  os.environ["HF_HOME"] = HF_CACHE
22
+ # Add HuggingFace token if needed
23
+ # os.environ["HF_TOKEN"] = "your-huggingface-api-token"
24
 
25
  from chromadb import PersistentClient
26
  from chromadb.utils.data_loaders import ImageLoader
 
31
  IMAGES_DIR = os.path.join(TEMP_DIR, "extracted_images")
32
  DB_PATH = os.path.join(TEMP_DIR, "image_vdb")
33
  os.makedirs(IMAGES_DIR, exist_ok=True)
34
+ os.makedirs(DB_PATH, exist_ok=True)
35
 
36
  # Init ChromaDB collection
37
  @st.cache_resource
38
  def get_chroma_collection():
39
+ try:
40
+ chroma_client = PersistentClient(path=DB_PATH)
41
+ image_loader = ImageLoader()
42
+ embedding_fn = OpenCLIPEmbeddingFunction()
43
+ collection = chroma_client.get_or_create_collection(
44
+ name="image", embedding_function=embedding_fn, data_loader=image_loader
45
+ )
46
+ return collection
47
+ except Exception as e:
48
+ logger.error(f"Error initializing ChromaDB: {e}")
49
+ st.error(f"Failed to initialize ChromaDB: {e}")
50
+ return None
51
 
52
  image_collection = get_chroma_collection()
53
+ if image_collection is None:
54
+ st.stop()
55
 
56
  # --- Extract images from PDF ---
57
  def extract_images_from_pdf(pdf_bytes):
58
+ try:
59
+ pdf = fitz.open(stream=pdf_bytes, filetype="pdf")
60
+ saved = []
61
+ for i in range(len(pdf)):
62
+ for img in pdf.load_page(i).get_images(full=True):
63
+ base = pdf.extract_image(img[0])
64
+ ext = base["ext"]
65
+ path = os.path.join(IMAGES_DIR, f"pdf_p{i+1}_img{img[0]}.{ext}")
66
+ with open(path, "wb") as f:
67
+ f.write(base["image"])
68
+ saved.append(path)
69
+ return saved
70
+ except Exception as e:
71
+ logger.error(f"Error extracting images from PDF: {e}")
72
+ st.error(f"Failed to extract images: {e}")
73
+ return []
74
 
75
  # --- Index images ---
76
  def index_images(paths):
77
+ try:
78
+ ids, uris = [], []
79
+ for path in sorted(paths):
80
+ if path.lower().endswith((".jpg", ".jpeg", ".png", ".bmp", ".tiff", ".webp")):
81
+ ids.append(str(uuid.uuid4()))
82
+ uris.append(path)
83
+ if ids:
84
+ image_collection.add(ids=ids, uris=uris)
85
+ except Exception as e:
86
+ logger.error(f"Error indexing images: {e}")
87
+ st.error(f"Failed to index images: {e}")
88
 
89
  # --- Image-to-Image search ---
90
  def query_similar_images(image_file, top_k=5):
91
+ try:
92
+ with tempfile.NamedTemporaryFile(delete=False, suffix=".jpg") as tmp:
93
+ tmp.write(image_file.read())
94
+ tmp.flush()
95
+ res = image_collection.query(query_uris=[tmp.name], n_results=top_k)
96
+ os.remove(tmp.name)
97
+ if not res or 'uris' not in res or not res['uris'] or not res['uris'][0]:
98
+ return []
99
+ return res['uris'][0]
100
+ except Exception as e:
101
+ logger.error(f"Error in image-to-image search: {e}")
102
+ st.error(f"Failed to perform image search: {e}")
103
  return []
 
104
 
105
  # --- Text-to-Image search ---
106
  def search_images_by_text(text, top_k=5):
107
+ try:
108
+ res = image_collection.query(query_texts=[text], n_results=top_k)
109
+ if not res or 'uris' not in res or not res['uris'] or not res['uris'][0]:
110
+ return []
111
+ return res['uris'][0]
112
+ except Exception as e:
113
+ logger.error(f"Error in text-to-image search: {e}")
114
+ st.error(f"Failed to perform text search: {e}")
115
  return []
 
116
 
117
  # --- Load CIFAR-10 Demo Dataset (500 images) ---
118
  @st.cache_resource
119
  def load_demo_cifar10(n=500):
120
+ try:
121
+ dataset = CIFAR10(root=TEMP_DIR, download=True, train=True)
122
+ transform = T.ToPILImage()
123
+ saved = []
124
+
125
+ progress_bar = st.progress(0)
126
+ for i in range(min(n, len(dataset))):
127
+ img, label = dataset[i]
128
+ if not isinstance(img, Image.Image):
129
+ img = transform(img)
130
+ path = os.path.join(IMAGES_DIR, f"cifar10_{i}_{label}.png")
131
+ img.save(path)
132
+ saved.append(path)
133
+ if i % 10 == 0 or i == n - 1:
134
+ progress_bar.progress((i + 1) / n)
135
+ return saved
136
+ except Exception as e:
137
+ logger.error(f"Error loading CIFAR-10 dataset: {e}")
138
+ st.error(f"Failed to load CIFAR-10 dataset: {e}")
139
+ return []
140
 
141
  # === UI START ===
142
  st.title("🔍 Semantic Image Search App")
 
149
  if pdf:
150
  with st.spinner("Extracting images from PDF..."):
151
  imgs = extract_images_from_pdf(pdf.read())
152
+ if imgs:
153
+ index_images(imgs)
154
+ st.success(f"✅ Indexed {len(imgs)} images from PDF.")
155
+ st.image(imgs, width=120)
156
+ else:
157
+ st.warning("No images extracted from PDF.")
158
 
159
  elif choice == "Upload Images":
160
  imgs = st.file_uploader("📤 Upload image files", type=["jpg", "jpeg", "png", "bmp", "tiff", "webp"], accept_multiple_files=True)
 
174
  if st.button("🔄 Load 500 CIFAR‑10 Images"):
175
  with st.spinner("Loading CIFAR‑10 demo dataset..."):
176
  paths = load_demo_cifar10(500)
177
+ if paths:
178
+ index_images(paths)
179
+ st.success("✅ 500 demo images loaded and indexed.")
180
+ st.image(paths[:20], width=100)
181
+ else:
182
+ st.warning("Failed to load CIFAR-10 images.")
183
 
184
  # Step 2: Search
185
  st.divider()
 
207
  else:
208
  st.subheader("🔍 Semantic Matches:")
209
  for u in results:
210
+ st.image(u, width=150)