File size: 9,924 Bytes
ae74f7a 03d82bf |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 |
import os
import streamlit as st
from PIL import Image as PILImage
from PIL import Image as pilImage
import base64
import io
import chromadb
from initate import process_pdf
from utils.llm_ag import intiate_convo
from utils.doi import process_image_and_get_description
# Module-level ChromaDB PersistentClient pointed at the "mm_vdb2" path.
# NOTE(review): `client` is not referenced anywhere else in this file; the
# collections the pages use come from process_pdf() via st.session_state.
# Presumably kept for its import-time side effect or for other modules — confirm.
path = "mm_vdb2"
client = chromadb.PersistentClient(path=path)
import streamlit as st
from PIL import Image as PILImage
def display_images(image_collection, query_text, max_distance=None, debug=False, max_results=10):
    """
    Display images in a Streamlit app based on a query.

    The ``max_results`` nearest hits are requested from the collection and
    rendered in ascending URI order (not distance order). Hits whose distance
    exceeds ``max_distance`` are skipped.

    Args:
        image_collection: The image collection object for querying; its
            ``query`` result must contain ``'uris'`` and ``'distances'``.
        query_text (str): The text query for images.
        max_distance (float, optional): Maximum allowable distance for
            filtering. ``None`` shows every hit.
        debug (bool, optional): Whether to write each URI/distance pair to the
            page, including hits that were filtered out.
        max_results (int, optional): Number of hits to request from the
            collection. Defaults to 10.
    """
    results = image_collection.query(
        query_texts=[query_text],
        n_results=max_results,
        include=['uris', 'distances'],
    )
    uris = results['uris'][0]
    distances = results['distances'][0]

    # Combine uris and distances, then sort by URI in ascending order.
    for uri, distance in sorted(zip(uris, distances), key=lambda pair: pair[0]):
        if max_distance is None or distance <= max_distance:
            if debug:
                st.write(f"URI: {uri} - Distance: {distance}")
            try:
                # Context manager releases the file handle once Streamlit has
                # consumed the image, instead of waiting for garbage collection.
                with PILImage.open(uri) as img:
                    st.image(img, width=300)
            except Exception as e:
                # Best-effort display: one unreadable file must not abort the rest.
                st.error(f"Error loading image {uri}: {e}")
        elif debug:
            st.write(f"URI: {uri} - Distance: {distance} (Filtered out)")
def display_videos_streamlit(video_collection, query_text, max_distance=None, max_results=5, debug=False):
    """
    Play the videos whose indexed entries best match a text query.

    Every hit's metadata must carry a ``'video_uri'`` key. A given video is
    played at most once, even when several of its hits match.

    Args:
        video_collection: Collection queried with ``uris``, ``distances`` and
            ``metadatas`` included.
        query_text (str): The text query for videos.
        max_distance (float, optional): Hits farther than this are skipped;
            ``None`` disables the distance check.
        max_results (int, optional): Number of hits to request.
        debug (bool, optional): Write every hit (shown or skipped) to the page.
    """
    hits = video_collection.query(
        query_texts=[query_text],
        n_results=max_results,
        include=['uris', 'distances', 'metadatas'],
    )
    already_shown = set()
    rows = zip(hits['uris'][0], hits['distances'][0], hits['metadatas'][0])
    for frame_uri, dist, meta in rows:
        target = meta['video_uri']
        close_enough = max_distance is None or dist <= max_distance
        # Short-circuit keeps the set lookup off the path for distant hits.
        if close_enough and target not in already_shown:
            if debug:
                st.write(f"URI: {frame_uri} - Video URI: {target} - Distance: {dist}")
            st.video(target)
            already_shown.add(target)
        elif debug:
            st.write(f"URI: {frame_uri} - Video URI: {target} - Distance: {dist} (Filtered out)")
def image_uris(image_collection, query_text, max_distance=None, max_results=5):
    """
    Return URIs of images in ``image_collection`` that match ``query_text``.

    Args:
        image_collection: Collection queried with ``uris`` and ``distances``.
        query_text (str): The text query.
        max_distance (float, optional): Keep only hits with distance
            ``<= max_distance``; ``None`` keeps every hit.
        max_results (int, optional): Number of hits to request.

    Returns:
        list: The kept URIs, in the order the collection returned them.
    """
    response = image_collection.query(
        query_texts=[query_text],
        n_results=max_results,
        include=['uris', 'distances'],
    )
    pairs = zip(response['uris'][0], response['distances'][0])
    return [u for u, d in pairs if max_distance is None or d <= max_distance]
def text_uris(text_collection, query_text, max_distance=None, max_results=5):
    """
    Return text documents from ``text_collection`` that match ``query_text``.

    Despite the name, the items returned are document contents (the
    ``'documents'`` field of the query result), not URIs.

    Args:
        text_collection: Collection queried with ``documents`` and ``distances``.
        query_text (str): The text query.
        max_distance (float, optional): Keep only hits with distance
            ``<= max_distance``; ``None`` keeps every hit.
        max_results (int, optional): Number of hits to request.

    Returns:
        list: The kept documents, in the order the collection returned them.
    """
    def within_limit(dist):
        return max_distance is None or dist <= max_distance

    response = text_collection.query(
        query_texts=[query_text],
        n_results=max_results,
        include=['documents', 'distances'],
    )
    documents, scores = response['documents'][0], response['distances'][0]
    return [document for document, score in zip(documents, scores) if within_limit(score)]
def frame_uris(video_collection, query_text, max_distance=None, max_results=5):
    """
    Return matching frame URIs, at most one per parent directory.

    Hits are taken in the collection's order. A hit is dropped when it is too
    far away or when a frame from the same directory was already taken.
    Collection stops once ``max_results`` frames are gathered.

    Args:
        video_collection: Collection queried with ``uris`` and ``distances``.
        query_text (str): The text query.
        max_distance (float, optional): Keep only hits with distance
            ``<= max_distance``; ``None`` keeps every hit.
        max_results (int, optional): Number of hits to request and the cap on
            frames returned.

    Returns:
        list: The selected frame URIs.
    """
    response = video_collection.query(
        query_texts=[query_text],
        n_results=max_results,
        include=['uris', 'distances'],
    )
    picked = []
    folders_used = set()
    for frame, dist in zip(response['uris'][0], response['distances'][0]):
        if not (max_distance is None or dist <= max_distance):
            continue
        parent = os.path.dirname(frame)
        if parent in folders_used:
            continue
        folders_used.add(parent)
        picked.append(frame)
        if len(picked) == max_results:
            break
    return picked
def image_uris2(image_collection2, query_text, max_distance=None, max_results=5):
    """
    Return URIs of images in ``image_collection2`` that match ``query_text``.

    Same contract as ``image_uris``, applied to a second image collection.

    Args:
        image_collection2: Collection queried with ``uris`` and ``distances``.
        query_text (str): The text query.
        max_distance (float, optional): Keep only hits with distance
            ``<= max_distance``; ``None`` keeps every hit.
        max_results (int, optional): Number of hits to request.

    Returns:
        list: The kept URIs, in the order the collection returned them.
    """
    response = image_collection2.query(
        query_texts=[query_text],
        n_results=max_results,
        include=['uris', 'distances'],
    )
    matches = []
    for candidate, gap in zip(response['uris'][0], response['distances'][0]):
        if max_distance is not None and not (gap <= max_distance):
            continue
        matches.append(candidate)
    return matches
def _encode_image_thumbnail(image_path, scale=6, quality=60):
    """
    Return the image at *image_path* as base64-encoded greyscale JPEG text.

    Each side is integer-divided by *scale* (clamped to at least 1 px), the
    result is converted to mode ``"L"`` and saved as JPEG at *quality*.
    """
    with PILImage.open(image_path) as img:
        # Clamp to 1 px: ``// scale`` yields 0 for images smaller than
        # ``scale`` on a side, and a zero-sized resize target fails.
        size = (max(1, img.width // scale), max(1, img.height // scale))
        thumb = img.resize(size).convert("L")
        with io.BytesIO() as output:
            thumb.save(output, format="JPEG", quality=quality)
            compressed_image_data = output.getvalue()
    return base64.b64encode(compressed_image_data).decode('utf-8')


def format_prompt_inputs(image_collection, text_collection, video_collection, user_query):
    """
    Gather retrieval results for ``user_query`` into a prompt-input dict.

    Each collection is queried with its own fixed distance threshold
    (frames 1.55, images 1.5, texts 1.3). The returned dict has the keys:

    - ``"query"``: ``user_query`` unchanged.
    - ``"texts"``: matching text documents (possibly empty).
    - ``"frame"``: URI of the first matching frame, or ``""``.
    - ``"image_data_1"``: base64 JPEG of the first matching image, shrunk to
      1/6 size and greyscale, or ``""`` when no image matched.
    """
    frame_candidates = frame_uris(video_collection, user_query, max_distance=1.55)
    image_candidates = image_uris(image_collection, user_query, max_distance=1.5)
    texts = text_uris(text_collection, user_query, max_distance=1.3)

    inputs = {"query": user_query, "texts": texts}
    inputs["frame"] = frame_candidates[0] if frame_candidates else ""
    # Only the single best image is sent, compressed to keep the payload small.
    inputs["image_data_1"] = (
        _encode_image_thumbnail(image_candidates[0]) if image_candidates else ""
    )
    return inputs
def page_1():
    """
    Render the "Upload and Process PDF" page.

    After a PDF is uploaded and "Process PDF" is pressed, the file is written
    to the system temporary directory and passed to ``process_pdf``. The three
    collections it returns are stored in ``st.session_state`` as
    ``image_collection``, ``text_collection`` and ``video_collection`` for use
    by ``page_2``. Any exception raised while processing is shown via
    ``st.error`` and the progress widgets are reset.
    """
    import tempfile
    import time

    st.title("Page 1: Upload and Process PDF")
    # File uploader for PDF
    uploaded_file = st.file_uploader("Upload a PDF file", type=["pdf"])
    # Button to trigger processing
    if uploaded_file and st.button("Process PDF"):
        # The file name is client-supplied: keep only its final component so a
        # crafted name such as "../../x.pdf" cannot escape the temp directory.
        safe_name = os.path.basename(uploaded_file.name) or "upload.pdf"
        pdf_path = os.path.join(tempfile.gettempdir(), safe_name)
        with open(pdf_path, "wb") as f:
            f.write(uploaded_file.getbuffer())

        # Progress bar
        progress_bar = st.progress(0)
        status_text = st.empty()
        try:
            progress_bar.progress(10)
            status_text.text("Initializing processing...")
            # Simulated progress steps (10, 40, 70) shown before the real work.
            for progress in range(10, 100, 30):
                # stdlib ``time``: the ``streamlit`` module has no ``time`` attribute.
                time.sleep(0.5)
                progress_bar.progress(progress)
                status_text.text(f"Processing... {progress}%")
            # Process the PDF and save collections to session state
            image_collection, text_collection, video_collection = process_pdf(pdf_path)
            st.session_state.image_collection = image_collection
            st.session_state.text_collection = text_collection
            st.session_state.video_collection = video_collection
            progress_bar.progress(100)
            status_text.text("Processing completed successfully!")
            st.success("PDF processed successfully! Collections saved to session state.")
        except Exception as e:
            # UI boundary: reset the progress widgets and surface the error.
            progress_bar.progress(0)
            status_text.text("")
            st.error(f"Error processing PDF: {e}")
def page_2():
    """
    Render the "Query and Use Processed Collections" page.

    Requires the three collections stored in ``st.session_state`` by
    ``page_1``; otherwise an error is shown. For a non-empty query, it:

    1. Builds the prompt inputs.
    2. Replaces any encoded image with its description from
       ``process_image_and_get_description``.
    3. Writes the ``intiate_convo`` response.
    4. Lists matching images.
    5. Plays the video associated with the best frame, if the file exists.
    """
    st.title("Page 2: Query and Use Processed Collections")
    state = st.session_state
    required = ("image_collection", "text_collection", "video_collection")
    if not all(key in state for key in required):
        st.error("Collections not found in session state. Please process the PDF on Page 1.")
        return

    image_collection = state.image_collection
    text_collection = state.text_collection
    video_collection = state.video_collection
    st.success("Collections loaded successfully.")

    query = st.text_input("Enter your query", value="Example Query")
    if not query:
        return

    inputs = format_prompt_inputs(image_collection, text_collection, video_collection, query)
    image_context = inputs["image_data_1"]
    if image_context:
        image_context = process_image_and_get_description(image_context)
    response = intiate_convo(query, image_context, inputs["texts"])
    st.write("Response:", response)

    st.markdown("### Images")
    display_images(image_collection, query, max_distance=1.55, debug=True)

    st.markdown("### Videos")
    frame = inputs["frame"]
    if not frame:
        return
    # basename() already strips every directory component, so no further
    # splitting on "/" is needed.
    # NOTE(review): the video is named after the frame *file*, while frame_uris
    # de-duplicates by the frame's parent folder — the folder name may be the
    # intended video id. Confirm against the layout process_pdf produces.
    video_path = f"StockVideos-CC0/{os.path.basename(frame)}.mp4"
    if os.path.exists(video_path):
        st.video(video_path)
    else:
        st.write("No related videos found.")
# --- Navigation ---
# Sidebar label -> function that renders that page.
PAGES = {
    "Upload and Process PDF": page_1,
    "Query and Use Processed Collections": page_2,
}
# Select page
selected_page = st.sidebar.selectbox("Choose a page", options=list(PAGES.keys()))
# Render selected page
PAGES[selected_page]()