NEXAS commited on
Commit
ae74f7a
·
verified ·
1 Parent(s): bb6635b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +271 -249
app.py CHANGED
@@ -1,250 +1,272 @@
1
- import os
2
- import streamlit as st
3
- from PIL import Image as PILImage
4
- from PIL import Image as pilImage
5
- import base64
6
- import io
7
- import chromadb
8
- from initate import process_pdf
9
- from utils.llm_ag import intiate_convo
10
- from utils.doi import process_image_and_get_description
11
-
12
- path = "mm_vdb2"
13
- client = chromadb.PersistentClient(path=path)
14
- import streamlit as st
15
- from PIL import Image as PILImage
16
-
17
- def display_images(image_collection, query_text, max_distance=None, debug=False):
18
- """
19
- Display images in a Streamlit app based on a query.
20
-
21
- Args:
22
- image_collection: The image collection object for querying.
23
- query_text (str): The text query for images.
24
- max_distance (float, optional): Maximum allowable distance for filtering.
25
- debug (bool, optional): Whether to print debug information.
26
- """
27
- results = image_collection.query(
28
- query_texts=[query_text],
29
- n_results=10,
30
- include=['uris', 'distances']
31
- )
32
-
33
- uris = results['uris'][0]
34
- distances = results['distances'][0]
35
-
36
- # Combine uris and distances, then sort by URI in ascending order
37
- sorted_results = sorted(zip(uris, distances), key=lambda x: x[0])
38
-
39
- # Filter and display images
40
- for uri, distance in sorted_results:
41
- if max_distance is None or distance <= max_distance:
42
- if debug:
43
- st.write(f"URI: {uri} - Distance: {distance}")
44
- try:
45
- img = PILImage.open(uri)
46
- st.image(img, width=300)
47
- except Exception as e:
48
- st.error(f"Error loading image {uri}: {e}")
49
- else:
50
- if debug:
51
- st.write(f"URI: {uri} - Distance: {distance} (Filtered out)")
52
-
53
-
54
-
55
- def display_videos_streamlit(video_collection, query_text, max_distance=None, max_results=5, debug=False):
56
- """
57
- Display videos in a Streamlit app based on a query.
58
-
59
- Args:
60
- video_collection: The video collection object for querying.
61
- query_text (str): The text query for videos.
62
- max_distance (float, optional): Maximum allowable distance for filtering.
63
- max_results (int, optional): Maximum number of results to display.
64
- debug (bool, optional): Whether to print debug information.
65
- """
66
- # Deduplication set
67
- displayed_videos = set()
68
-
69
- # Query the video collection with the specified text
70
- results = video_collection.query(
71
- query_texts=[query_text],
72
- n_results=max_results, # Adjust the number of results if needed
73
- include=['uris', 'distances', 'metadatas']
74
- )
75
-
76
- # Extract URIs, distances, and metadatas from the result
77
- uris = results['uris'][0]
78
- distances = results['distances'][0]
79
- metadatas = results['metadatas'][0]
80
-
81
- # Display the videos that meet the distance criteria
82
- for uri, distance, metadata in zip(uris, distances, metadatas):
83
- video_uri = metadata['video_uri']
84
-
85
- # Check if a max_distance filter is applied and the distance is within the allowed range
86
- if (max_distance is None or distance <= max_distance) and video_uri not in displayed_videos:
87
- if debug:
88
- st.write(f"URI: {uri} - Video URI: {video_uri} - Distance: {distance}")
89
- st.video(video_uri) # Display video in Streamlit
90
- displayed_videos.add(video_uri) # Add to the set to prevent duplication
91
- else:
92
- if debug:
93
- st.write(f"URI: {uri} - Video URI: {video_uri} - Distance: {distance} (Filtered out)")
94
-
95
-
96
- def image_uris(image_collection,query_text, max_distance=None, max_results=5):
97
- results = image_collection.query(
98
- query_texts=[query_text],
99
- n_results=max_results,
100
- include=['uris', 'distances']
101
- )
102
-
103
- filtered_uris = []
104
- for uri, distance in zip(results['uris'][0], results['distances'][0]):
105
- if max_distance is None or distance <= max_distance:
106
- filtered_uris.append(uri)
107
-
108
- return filtered_uris
109
-
110
- def text_uris(text_collection,query_text, max_distance=None, max_results=5):
111
- results = text_collection.query(
112
- query_texts=[query_text],
113
- n_results=max_results,
114
- include=['documents', 'distances']
115
- )
116
-
117
- filtered_texts = []
118
- for doc, distance in zip(results['documents'][0], results['distances'][0]):
119
- if max_distance is None or distance <= max_distance:
120
- filtered_texts.append(doc)
121
-
122
- return filtered_texts
123
-
124
- def frame_uris(video_collection,query_text, max_distance=None, max_results=5):
125
- results = video_collection.query(
126
- query_texts=[query_text],
127
- n_results=max_results,
128
- include=['uris', 'distances']
129
- )
130
-
131
- filtered_uris = []
132
- seen_folders = set()
133
-
134
- for uri, distance in zip(results['uris'][0], results['distances'][0]):
135
- if max_distance is None or distance <= max_distance:
136
- folder = os.path.dirname(uri)
137
- if folder not in seen_folders:
138
- filtered_uris.append(uri)
139
- seen_folders.add(folder)
140
-
141
- if len(filtered_uris) == max_results:
142
- break
143
-
144
- return filtered_uris
145
-
146
- def image_uris2(image_collection2,query_text, max_distance=None, max_results=5):
147
- results = image_collection2.query(
148
- query_texts=[query_text],
149
- n_results=max_results,
150
- include=['uris', 'distances']
151
- )
152
-
153
- filtered_uris = []
154
- for uri, distance in zip(results['uris'][0], results['distances'][0]):
155
- if max_distance is None or distance <= max_distance:
156
- filtered_uris.append(uri)
157
-
158
- return filtered_uris
159
-
160
-
161
- def format_prompt_inputs(image_collection, text_collection, video_collection, user_query):
162
- frame_candidates = frame_uris(video_collection, user_query, max_distance=1.55)
163
- image_candidates = image_uris(image_collection, user_query, max_distance=1.5)
164
- texts = text_uris(text_collection, user_query, max_distance=1.3)
165
-
166
- inputs = {"query": user_query, "texts": texts}
167
- frame = frame_candidates[0] if frame_candidates else ""
168
- inputs["frame"] = frame
169
-
170
- if image_candidates:
171
- image = image_candidates[0]
172
- with PILImage.open(image) as img:
173
- img = img.resize((img.width // 6, img.height // 6))
174
- img = img.convert("L")
175
- with io.BytesIO() as output:
176
- img.save(output, format="JPEG", quality=60)
177
- compressed_image_data = output.getvalue()
178
-
179
- inputs["image_data_1"] = base64.b64encode(compressed_image_data).decode('utf-8')
180
- else:
181
- inputs["image_data_1"] = ""
182
-
183
- return inputs
184
-
185
- def page_1():
186
- st.title("Page 1: Upload and Process PDF")
187
-
188
- uploaded_file = st.file_uploader("Upload a PDF file", type=["pdf"])
189
- if uploaded_file:
190
- pdf_path = f"/tmp/{uploaded_file.name}"
191
- with open(pdf_path, "wb") as f:
192
- f.write(uploaded_file.getbuffer())
193
-
194
- try:
195
- image_collection, text_collection, video_collection = process_pdf(pdf_path)
196
- st.session_state.image_collection = image_collection
197
- st.session_state.text_collection = text_collection
198
- st.session_state.video_collection = video_collection
199
-
200
- st.success("PDF processed successfully! Collections saved to session state.")
201
- except Exception as e:
202
- st.error(f"Error processing PDF: {e}")
203
-
204
- def page_2():
205
- st.title("Page 2: Query and Use Processed Collections")
206
-
207
- if "image_collection" in st.session_state and "text_collection" in st.session_state and "video_collection" in st.session_state:
208
- image_collection = st.session_state.image_collection
209
- text_collection = st.session_state.text_collection
210
- video_collection = st.session_state.video_collection
211
- st.success("Collections loaded successfully.")
212
-
213
- query = st.text_input("Enter your query", value="Example Query")
214
- if query:
215
- inputs = format_prompt_inputs(image_collection, text_collection, video_collection, query)
216
- texts = inputs["texts"]
217
- image_data_1 = inputs["image_data_1"]
218
-
219
- if image_data_1:
220
- image_data_1 = process_image_and_get_description(image_data_1)
221
-
222
- response = intiate_convo(query, image_data_1, texts)
223
- st.write("Response:", response)
224
-
225
- st.markdown("### Images")
226
- display_images(image_collection, query, max_distance=1.55, debug=True)
227
-
228
- st.markdown("### Videos")
229
- frame = inputs["frame"]
230
- if frame:
231
- video_path = f"StockVideos-CC0/{os.path.basename(frame).split('/')[0]}.mp4"
232
- if os.path.exists(video_path):
233
- st.video(video_path)
234
- else:
235
- st.write("No related videos found.")
236
- else:
237
- st.error("Collections not found in session state. Please process the PDF on Page 1.")
238
-
239
- # --- Navigation ---
240
-
241
- PAGES = {
242
- "Upload and Process PDF": page_1,
243
- "Query and Use Processed Collections": page_2
244
- }
245
-
246
- # Select page
247
- selected_page = st.sidebar.selectbox("Choose a page", options=list(PAGES.keys()))
248
-
249
- # Render selected page
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
250
  PAGES[selected_page]()
 
1
+ import os
2
+ import streamlit as st
3
+ from PIL import Image as PILImage
4
+ from PIL import Image as pilImage
5
+ import base64
6
+ import io
7
+ import chromadb
8
+ from initate import process_pdf
9
+ from utils.llm_ag import intiate_convo
10
+ from utils.doi import process_image_and_get_description
11
+
12
+ path = "mm_vdb2"
13
+ client = chromadb.PersistentClient(path=path)
14
+ import streamlit as st
15
+ from PIL import Image as PILImage
16
+
17
+ def display_images(image_collection, query_text, max_distance=None, debug=False):
18
+ """
19
+ Display images in a Streamlit app based on a query.
20
+
21
+ Args:
22
+ image_collection: The image collection object for querying.
23
+ query_text (str): The text query for images.
24
+ max_distance (float, optional): Maximum allowable distance for filtering.
25
+ debug (bool, optional): Whether to print debug information.
26
+ """
27
+ results = image_collection.query(
28
+ query_texts=[query_text],
29
+ n_results=10,
30
+ include=['uris', 'distances']
31
+ )
32
+
33
+ uris = results['uris'][0]
34
+ distances = results['distances'][0]
35
+
36
+ # Combine uris and distances, then sort by URI in ascending order
37
+ sorted_results = sorted(zip(uris, distances), key=lambda x: x[0])
38
+
39
+ # Filter and display images
40
+ for uri, distance in sorted_results:
41
+ if max_distance is None or distance <= max_distance:
42
+ if debug:
43
+ st.write(f"URI: {uri} - Distance: {distance}")
44
+ try:
45
+ img = PILImage.open(uri)
46
+ st.image(img, width=300)
47
+ except Exception as e:
48
+ st.error(f"Error loading image {uri}: {e}")
49
+ else:
50
+ if debug:
51
+ st.write(f"URI: {uri} - Distance: {distance} (Filtered out)")
52
+
53
+
54
+
55
+ def display_videos_streamlit(video_collection, query_text, max_distance=None, max_results=5, debug=False):
56
+ """
57
+ Display videos in a Streamlit app based on a query.
58
+
59
+ Args:
60
+ video_collection: The video collection object for querying.
61
+ query_text (str): The text query for videos.
62
+ max_distance (float, optional): Maximum allowable distance for filtering.
63
+ max_results (int, optional): Maximum number of results to display.
64
+ debug (bool, optional): Whether to print debug information.
65
+ """
66
+ # Deduplication set
67
+ displayed_videos = set()
68
+
69
+ # Query the video collection with the specified text
70
+ results = video_collection.query(
71
+ query_texts=[query_text],
72
+ n_results=max_results, # Adjust the number of results if needed
73
+ include=['uris', 'distances', 'metadatas']
74
+ )
75
+
76
+ # Extract URIs, distances, and metadatas from the result
77
+ uris = results['uris'][0]
78
+ distances = results['distances'][0]
79
+ metadatas = results['metadatas'][0]
80
+
81
+ # Display the videos that meet the distance criteria
82
+ for uri, distance, metadata in zip(uris, distances, metadatas):
83
+ video_uri = metadata['video_uri']
84
+
85
+ # Check if a max_distance filter is applied and the distance is within the allowed range
86
+ if (max_distance is None or distance <= max_distance) and video_uri not in displayed_videos:
87
+ if debug:
88
+ st.write(f"URI: {uri} - Video URI: {video_uri} - Distance: {distance}")
89
+ st.video(video_uri) # Display video in Streamlit
90
+ displayed_videos.add(video_uri) # Add to the set to prevent duplication
91
+ else:
92
+ if debug:
93
+ st.write(f"URI: {uri} - Video URI: {video_uri} - Distance: {distance} (Filtered out)")
94
+
95
+
96
+ def image_uris(image_collection,query_text, max_distance=None, max_results=5):
97
+ results = image_collection.query(
98
+ query_texts=[query_text],
99
+ n_results=max_results,
100
+ include=['uris', 'distances']
101
+ )
102
+
103
+ filtered_uris = []
104
+ for uri, distance in zip(results['uris'][0], results['distances'][0]):
105
+ if max_distance is None or distance <= max_distance:
106
+ filtered_uris.append(uri)
107
+
108
+ return filtered_uris
109
+
110
+ def text_uris(text_collection,query_text, max_distance=None, max_results=5):
111
+ results = text_collection.query(
112
+ query_texts=[query_text],
113
+ n_results=max_results,
114
+ include=['documents', 'distances']
115
+ )
116
+
117
+ filtered_texts = []
118
+ for doc, distance in zip(results['documents'][0], results['distances'][0]):
119
+ if max_distance is None or distance <= max_distance:
120
+ filtered_texts.append(doc)
121
+
122
+ return filtered_texts
123
+
124
+ def frame_uris(video_collection,query_text, max_distance=None, max_results=5):
125
+ results = video_collection.query(
126
+ query_texts=[query_text],
127
+ n_results=max_results,
128
+ include=['uris', 'distances']
129
+ )
130
+
131
+ filtered_uris = []
132
+ seen_folders = set()
133
+
134
+ for uri, distance in zip(results['uris'][0], results['distances'][0]):
135
+ if max_distance is None or distance <= max_distance:
136
+ folder = os.path.dirname(uri)
137
+ if folder not in seen_folders:
138
+ filtered_uris.append(uri)
139
+ seen_folders.add(folder)
140
+
141
+ if len(filtered_uris) == max_results:
142
+ break
143
+
144
+ return filtered_uris
145
+
146
+ def image_uris2(image_collection2,query_text, max_distance=None, max_results=5):
147
+ results = image_collection2.query(
148
+ query_texts=[query_text],
149
+ n_results=max_results,
150
+ include=['uris', 'distances']
151
+ )
152
+
153
+ filtered_uris = []
154
+ for uri, distance in zip(results['uris'][0], results['distances'][0]):
155
+ if max_distance is None or distance <= max_distance:
156
+ filtered_uris.append(uri)
157
+
158
+ return filtered_uris
159
+
160
+
161
+ def format_prompt_inputs(image_collection, text_collection, video_collection, user_query):
162
+ frame_candidates = frame_uris(video_collection, user_query, max_distance=1.55)
163
+ image_candidates = image_uris(image_collection, user_query, max_distance=1.5)
164
+ texts = text_uris(text_collection, user_query, max_distance=1.3)
165
+
166
+ inputs = {"query": user_query, "texts": texts}
167
+ frame = frame_candidates[0] if frame_candidates else ""
168
+ inputs["frame"] = frame
169
+
170
+ if image_candidates:
171
+ image = image_candidates[0]
172
+ with PILImage.open(image) as img:
173
+ img = img.resize((img.width // 6, img.height // 6))
174
+ img = img.convert("L")
175
+ with io.BytesIO() as output:
176
+ img.save(output, format="JPEG", quality=60)
177
+ compressed_image_data = output.getvalue()
178
+
179
+ inputs["image_data_1"] = base64.b64encode(compressed_image_data).decode('utf-8')
180
+ else:
181
+ inputs["image_data_1"] = ""
182
+
183
+ return inputs
184
+
185
+ def page_1():
186
+ st.title("Page 1: Upload and Process PDF")
187
+
188
+ # File uploader for PDF
189
+ uploaded_file = st.file_uploader("Upload a PDF file", type=["pdf"])
190
+
191
+ # Button to trigger processing
192
+ if uploaded_file and st.button("Process PDF"):
193
+ pdf_path = f"/tmp/{uploaded_file.name}"
194
+ with open(pdf_path, "wb") as f:
195
+ f.write(uploaded_file.getbuffer())
196
+
197
+ # Progress bar
198
+ progress_bar = st.progress(0)
199
+ status_text = st.empty()
200
+
201
+ try:
202
+ progress_bar.progress(10)
203
+ status_text.text("Initializing processing...")
204
+
205
+ # Simulating progress during processing
206
+ for progress in range(10, 100, 30):
207
+ st.time.sleep(0.5) # Simulate processing delay
208
+ progress_bar.progress(progress)
209
+ status_text.text(f"Processing... {progress}%")
210
+
211
+ # Process the PDF and save collections to session state
212
+ image_collection, text_collection, video_collection = process_pdf(pdf_path)
213
+ st.session_state.image_collection = image_collection
214
+ st.session_state.text_collection = text_collection
215
+ st.session_state.video_collection = video_collection
216
+
217
+ progress_bar.progress(100)
218
+ status_text.text("Processing completed successfully!")
219
+ st.success("PDF processed successfully! Collections saved to session state.")
220
+ except Exception as e:
221
+ progress_bar.progress(0)
222
+ status_text.text("")
223
+ st.error(f"Error processing PDF: {e}")
224
+
225
+
226
+ def page_2():
227
+ st.title("Page 2: Query and Use Processed Collections")
228
+
229
+ if "image_collection" in st.session_state and "text_collection" in st.session_state and "video_collection" in st.session_state:
230
+ image_collection = st.session_state.image_collection
231
+ text_collection = st.session_state.text_collection
232
+ video_collection = st.session_state.video_collection
233
+ st.success("Collections loaded successfully.")
234
+
235
+ query = st.text_input("Enter your query", value="Example Query")
236
+ if query:
237
+ inputs = format_prompt_inputs(image_collection, text_collection, video_collection, query)
238
+ texts = inputs["texts"]
239
+ image_data_1 = inputs["image_data_1"]
240
+
241
+ if image_data_1:
242
+ image_data_1 = process_image_and_get_description(image_data_1)
243
+
244
+ response = intiate_convo(query, image_data_1, texts)
245
+ st.write("Response:", response)
246
+
247
+ st.markdown("### Images")
248
+ display_images(image_collection, query, max_distance=1.55, debug=True)
249
+
250
+ st.markdown("### Videos")
251
+ frame = inputs["frame"]
252
+ if frame:
253
+ video_path = f"StockVideos-CC0/{os.path.basename(frame).split('/')[0]}.mp4"
254
+ if os.path.exists(video_path):
255
+ st.video(video_path)
256
+ else:
257
+ st.write("No related videos found.")
258
+ else:
259
+ st.error("Collections not found in session state. Please process the PDF on Page 1.")
260
+
261
+ # --- Navigation ---
262
+
263
+ PAGES = {
264
+ "Upload and Process PDF": page_1,
265
+ "Query and Use Processed Collections": page_2
266
+ }
267
+
268
+ # Select page
269
+ selected_page = st.sidebar.selectbox("Choose a page", options=list(PAGES.keys()))
270
+
271
+ # Render selected page
272
  PAGES[selected_page]()