NEXAS committed
Commit 116caaa · verified · 1 Parent(s): 60c342d

Update src/streamlit_app.py

Files changed (1)
  1. src/streamlit_app.py +58 -39
src/streamlit_app.py CHANGED
@@ -9,7 +9,7 @@ from skimage.io import imsave
 from torchvision.datasets import CIFAR10
 import torchvision.transforms as T
 
-# Setup cache paths
+# Set HuggingFace cache directory
 HF_CACHE = os.path.join(tempfile.gettempdir(), "hf_cache")
 os.makedirs(HF_CACHE, exist_ok=True)
 os.environ["XDG_CACHE_HOME"] = HF_CACHE
@@ -19,12 +19,13 @@ from chromadb import PersistentClient
 from chromadb.utils.data_loaders import ImageLoader
 from chromadb.utils.embedding_functions import OpenCLIPEmbeddingFunction
 
-# Directories
+# Paths
 TEMP_DIR = tempfile.gettempdir()
 IMAGES_DIR = os.path.join(TEMP_DIR, "extracted_images")
 DB_PATH = os.path.join(TEMP_DIR, "image_vdb")
 os.makedirs(IMAGES_DIR, exist_ok=True)
 
+# Init ChromaDB collection
 @st.cache_resource
 def get_chroma_collection():
     chroma_client = PersistentClient(path=DB_PATH)
@@ -37,7 +38,7 @@ def get_chroma_collection():
 
 image_collection = get_chroma_collection()
 
-# — PDFs & Uploads —
+# --- Extract images from PDF ---
 def extract_images_from_pdf(pdf_bytes):
     pdf = fitz.open(stream=pdf_bytes, filetype="pdf")
     saved = []
@@ -46,20 +47,22 @@ def extract_images_from_pdf(pdf_bytes):
         base = pdf.extract_image(img[0])
         ext = base["ext"]
         path = os.path.join(IMAGES_DIR, f"pdf_p{i+1}_img{img[0]}.{ext}")
-        with open(path,"wb") as f: f.write(base["image"])
+        with open(path, "wb") as f:
+            f.write(base["image"])
         saved.append(path)
     return saved
 
+# --- Index images ---
 def index_images(paths):
     ids, uris = [], []
     for path in sorted(paths):
-        if path.lower().endswith((".jpg",".jpeg",".png")):
+        if path.lower().endswith((".jpg", ".jpeg", ".png", ".bmp", ".tiff", ".webp")):
             ids.append(str(uuid.uuid4()))
             uris.append(path)
     if ids:
         image_collection.add(ids=ids, uris=uris)
 
-# — Queries —
+# --- Image-to-Image search ---
 def query_similar_images(image_file, top_k=5):
     with tempfile.NamedTemporaryFile(delete=False, suffix=".jpg") as tmp:
         tmp.write(image_file.read())
@@ -68,16 +71,19 @@ def query_similar_images(image_file, top_k=5):
     os.remove(tmp.name)
     return res['uris'][0]
 
+# --- Text-to-Image search ---
 def search_images_by_text(text, top_k=5):
     res = image_collection.query(query_texts=[text], n_results=top_k)
     return res['uris'][0]
 
-# — Demo Dataset: CIFAR10 (500 images) —
+# --- Load CIFAR-10 Demo Dataset (500 images) ---
 @st.cache_resource
 def load_demo_cifar10(n=500):
     dataset = CIFAR10(root=TEMP_DIR, download=True, train=True)
     transform = T.ToPILImage()
     saved = []
+
+    progress_bar = st.progress(0)
     for i in range(min(n, len(dataset))):
         img, label = dataset[i]
         if not isinstance(img, Image.Image):
@@ -85,53 +91,66 @@ def load_demo_cifar10(n=500):
         path = os.path.join(IMAGES_DIR, f"cifar10_{i}_{label}.png")
         img.save(path)
         saved.append(path)
+        if i % 10 == 0 or i == n - 1:
+            progress_bar.progress((i + 1) / n)
+
     return saved
 
-# — UI Starts —
-st.title("🔍 Image & Text Similarity Search with 500‑Image Demo DB")
+# === UI START ===
+st.title("🔍 Semantic Image Search App")
 
-choice = st.radio("Select data source", ["Upload PDF", "Upload Images", "Load CIFAR‑10 Demo"], horizontal=True)
+# Step 1: Load data
+choice = st.radio("📂 Select Image Source", ["Upload PDF", "Upload Images", "Load CIFAR‑10 Demo"], horizontal=True)
 
-if choice=="Upload PDF":
-    pdf = st.file_uploader("📤 Upload PDF", type=["pdf"])
+if choice == "Upload PDF":
+    pdf = st.file_uploader("📤 Upload PDF file", type=["pdf"])
     if pdf:
-        with st.spinner("Extracting..."):
-            imgs = extract_images_from_pdf(pdf.read()); index_images(imgs)
-        st.success(f"{len(imgs)} images indexed from PDF")
+        with st.spinner("Extracting images from PDF..."):
+            imgs = extract_images_from_pdf(pdf.read())
+            index_images(imgs)
+        st.success(f"✅ Indexed {len(imgs)} images from PDF.")
         st.image(imgs, width=120)
 
-elif choice=="Upload Images":
-    imgs = st.file_uploader("📤 Upload images", accept_multiple_files=True, type=["jpg","jpeg","png"])
+elif choice == "Upload Images":
+    imgs = st.file_uploader("📤 Upload image files", type=["jpg", "jpeg", "png", "bmp", "tiff", "webp"], accept_multiple_files=True)
     if imgs:
-        paths=[]
-        for item in imgs:
-            p=os.path.join(IMAGES_DIR, item.name)
-            with open(p,"wb") as f: f.write(item.read()); paths.append(p)
-        index_images(paths)
-        st.success(f"{len(paths)} images uploaded & indexed")
+        with st.spinner("Indexing uploaded images..."):
+            paths = []
+            for item in imgs:
+                p = os.path.join(IMAGES_DIR, item.name)
+                with open(p, "wb") as f:
+                    f.write(item.read())
+                paths.append(p)
+            index_images(paths)
+        st.success(f"✅ {len(paths)} images indexed.")
         st.image(paths, width=120)
 
-elif choice=="Load CIFAR‑10 Demo":
+elif choice == "Load CIFAR‑10 Demo":
     if st.button("🔄 Load 500 CIFAR‑10 Images"):
-        paths=load_demo_cifar10(500); index_images(paths)
-        st.success("500 CIFAR‑10 demo images loaded and indexed")
+        with st.spinner("Loading CIFAR‑10 demo dataset..."):
+            paths = load_demo_cifar10(500)
+            index_images(paths)
+        st.success("✅ 500 demo images loaded and indexed.")
         st.image(paths[:20], width=100)
 
+# Step 2: Search
 st.divider()
-st.subheader("🔎 Image-Based Search")
-q = st.file_uploader("Upload a query image", type=["jpg","jpeg","png"])
+st.subheader("🖼️ Image-to-Image Search")
+q = st.file_uploader("📷 Upload a query image", type=["jpg", "jpeg", "png", "bmp", "tiff", "webp"])
 if q:
-    st.image(q, caption="Query");
-    with st.spinner("Searching..."):
-        out = query_similar_images(q, top_k=5)
-    st.subheader("Top Image Matches")
-    for u in out: st.image(u, width=150)
+    st.image(q, caption="Query Image", width=200)
+    with st.spinner("Finding similar images..."):
+        results = query_similar_images(q, top_k=5)
+    st.subheader("🔍 Top Matches:")
+    for u in results:
+        st.image(u, width=150)
 
 st.divider()
-st.subheader("📝 Text-to-Image Semantic Search")
-txt = st.text_input("Enter description (e.g. 'a beach'):")
+st.subheader("📝 Text-to-Image Search")
+txt = st.text_input("Describe what you're looking for (e.g., 'a beach', 'a cat', 'a red truck'):")
 if txt:
-    with st.spinner("Searching..."):
-        out = search_images_by_text(txt, top_k=5)
-    st.subheader("Top Semantic Matches")
-    for u in out: st.image(u, width=150)
+    with st.spinner("Finding images by semantic similarity..."):
+        results = search_images_by_text(txt, top_k=5)
+    st.subheader("🔍 Semantic Matches:")
+    for u in results:
+        st.image(u, width=150)
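Note on the ChromaDB setup this change relies on: the body of get_chroma_collection() falls outside the hunks shown above, so the sketch below is an assumption rather than the app's exact code. It shows the usual multimodal Chroma pattern that is consistent with the imports in this file (PersistentClient, OpenCLIPEmbeddingFunction, ImageLoader) and with the add(ids=..., uris=...) and query(query_texts=...) calls visible in the diff; the collection name and the explicit include=["uris"] are illustrative choices, not taken from the app.

import os
import tempfile
import uuid

from PIL import Image
from chromadb import PersistentClient
from chromadb.utils.data_loaders import ImageLoader
from chromadb.utils.embedding_functions import OpenCLIPEmbeddingFunction

DB_PATH = os.path.join(tempfile.gettempdir(), "image_vdb")

# Persistent client backed by the same temp-dir path the app uses.
client = PersistentClient(path=DB_PATH)

# Multimodal collection: OpenCLIP embeds images (loaded from their URIs by
# ImageLoader) and query text into one shared vector space.
collection = client.get_or_create_collection(
    name="image_collection",  # assumed name; the actual name is not shown in the diff
    embedding_function=OpenCLIPEmbeddingFunction(),
    data_loader=ImageLoader(),
)

# Index an image by file path (URI), as index_images() does in the app.
demo_path = os.path.join(tempfile.gettempdir(), "demo.png")
Image.new("RGB", (64, 64), "blue").save(demo_path)
collection.add(ids=[str(uuid.uuid4())], uris=[demo_path])

# Text-to-image query; 'uris' is requested explicitly here so file paths come back.
res = collection.query(query_texts=["a blue square"], n_results=1, include=["uris"])
print(res["uris"][0])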