Last commit not found
import os | |
import streamlit as st | |
from PIL import Image as PILImage | |
from PIL import Image as pilImage | |
import base64 | |
import io | |
import chromadb | |
from initate import process_pdf | |
from utils.llm_ag import intiate_convo | |
from utils.doi import process_image_and_get_description | |
# Directory of the persistent Chroma vector store. It can be overridden with
# the MM_VDB_PATH environment variable; the default is the original value.
path = os.environ.get("MM_VDB_PATH", "mm_vdb2")
# NOTE(review): `client` is not referenced by any code visible in this file
# (page_1 takes its collections from process_pdf's return value); creating it
# presumably still opens/creates the store directory at import time.
client = chromadb.PersistentClient(path=path)
import streamlit as st | |
from PIL import Image as PILImage | |
def display_images(image_collection, query_text, max_distance=None, debug=False, max_results=10):
    """
    Display images in a Streamlit app based on a query.

    Queries ``image_collection`` for the ``max_results`` nearest images to
    ``query_text``, orders the hits by URI (ascending, not by distance), and
    renders each hit that passes the ``max_distance`` filter at 300 px wide.
    An image that fails to open is reported with ``st.error`` and the
    remaining images are still shown.

    Args:
        image_collection: The image collection object for querying.
        query_text (str): The text query for images.
        max_distance (float, optional): Maximum allowable distance for
            filtering. ``None`` disables filtering.
        debug (bool, optional): Whether to print debug information.
        max_results (int, optional): Number of nearest neighbours requested
            from the collection. Defaults to 10, the previously hard-coded value.
    """
    results = image_collection.query(
        query_texts=[query_text],
        n_results=max_results,
        include=['uris', 'distances'],
    )
    uris = results['uris'][0]
    distances = results['distances'][0]

    # Combine uris and distances, then sort by URI in ascending order.
    sorted_results = sorted(zip(uris, distances), key=lambda x: x[0])

    for uri, distance in sorted_results:
        if max_distance is None or distance <= max_distance:
            if debug:
                st.write(f"URI: {uri} - Distance: {distance}")
            try:
                # The context manager closes the file handle once Streamlit
                # has consumed the image; previously the handle was leaked.
                with PILImage.open(uri) as img:
                    st.image(img, width=300)
            except Exception as e:
                st.error(f"Error loading image {uri}: {e}")
        elif debug:
            st.write(f"URI: {uri} - Distance: {distance} (Filtered out)")
def display_videos_streamlit(video_collection, query_text, max_distance=None, max_results=5, debug=False):
    """
    Display videos in a Streamlit app based on a query.

    Queries ``video_collection`` for ``query_text`` and plays the video named
    by each hit's ``metadata['video_uri']``. Each distinct video URI is played
    at most once, in the order the collection returns the hits.

    Args:
        video_collection: The video collection object for querying.
        query_text (str): The text query for videos.
        max_distance (float, optional): Maximum allowable distance for
            filtering. ``None`` disables filtering.
        max_results (int, optional): Number of records requested from the
            collection. Several records can point at the same video, so fewer
            than ``max_results`` videos may actually be displayed.
        debug (bool, optional): Whether to print debug information.
    """
    displayed_videos = set()  # video URIs already rendered (de-duplication)

    results = video_collection.query(
        query_texts=[query_text],
        n_results=max_results,
        include=['uris', 'distances', 'metadatas'],
    )
    uris = results['uris'][0]
    distances = results['distances'][0]
    metadatas = results['metadatas'][0]

    for uri, distance, metadata in zip(uris, distances, metadatas):
        # A record with no metadata (None) or without a 'video_uri' key used to
        # abort the whole page with TypeError/KeyError; skip such hits instead.
        video_uri = (metadata or {}).get('video_uri')
        if video_uri is None:
            if debug:
                st.write(f"URI: {uri} - Distance: {distance} (Skipped: no 'video_uri' metadata)")
            continue

        if (max_distance is None or distance <= max_distance) and video_uri not in displayed_videos:
            if debug:
                st.write(f"URI: {uri} - Video URI: {video_uri} - Distance: {distance}")
            st.video(video_uri)
            displayed_videos.add(video_uri)
        elif debug:
            st.write(f"URI: {uri} - Video URI: {video_uri} - Distance: {distance} (Filtered out)")
def image_uris(image_collection, query_text, max_distance=None, max_results=5):
    """Return URIs of the images nearest to ``query_text``.

    Requests ``max_results`` hits from ``image_collection`` and keeps those
    whose distance is at most ``max_distance`` (all of them when it is
    ``None``), preserving the order returned by the collection.
    """
    response = image_collection.query(
        query_texts=[query_text],
        n_results=max_results,
        include=['uris', 'distances'],
    )
    hits = zip(response['uris'][0], response['distances'][0])
    return [u for u, d in hits if max_distance is None or d <= max_distance]
def text_uris(text_collection, query_text, max_distance=None, max_results=5):
    """Return the text documents nearest to ``query_text``.

    Requests ``max_results`` hits from ``text_collection`` and keeps the
    documents whose distance is at most ``max_distance`` (every document
    when it is ``None``), in the order returned by the collection.
    """
    response = text_collection.query(
        query_texts=[query_text],
        n_results=max_results,
        include=['documents', 'distances'],
    )
    docs, dists = response['documents'][0], response['distances'][0]
    return [doc for doc, dist in zip(docs, dists) if max_distance is None or dist <= max_distance]
def frame_uris(video_collection, query_text, max_distance=None, max_results=5, n_candidates=None):
    """Return up to ``max_results`` frame URIs, at most one per folder.

    Frames are de-duplicated by their parent directory (``os.path.dirname``),
    keeping the first hit per folder in the order returned by the collection.

    Args:
        video_collection: Collection queried for frames.
        query_text (str): Text query.
        max_distance (float, optional): Hits farther than this are dropped;
            ``None`` keeps every hit.
        max_results (int, optional): Maximum number of URIs returned.
        n_candidates (int, optional): Number of hits requested from the
            collection before de-duplication. ``None`` (default) requests
            ``max_results``, the original behaviour. Pass a larger value so
            that repeated folders do not shrink the result below
            ``max_results``.

    Returns:
        list[str]: Selected frame URIs.
    """
    if n_candidates is None:
        n_candidates = max_results
    results = video_collection.query(
        query_texts=[query_text],
        n_results=n_candidates,
        include=['uris', 'distances'],
    )
    filtered_uris = []
    seen_folders = set()
    for uri, distance in zip(results['uris'][0], results['distances'][0]):
        # Check before appending so max_results <= 0 yields an empty list.
        if len(filtered_uris) >= max_results:
            break
        if max_distance is None or distance <= max_distance:
            folder = os.path.dirname(uri)
            if folder not in seen_folders:
                filtered_uris.append(uri)
                seen_folders.add(folder)
    return filtered_uris
def image_uris2(image_collection2, query_text, max_distance=None, max_results=5):
    """Return URIs of the images in a second image collection nearest to ``query_text``.

    This was a line-for-line copy of ``image_uris``; it now delegates to that
    function so the query/filter logic lives in one place. Hits farther than
    ``max_distance`` are dropped (``None`` keeps all), and at most
    ``max_results`` hits are requested.
    """
    return image_uris(
        image_collection2,
        query_text,
        max_distance=max_distance,
        max_results=max_results,
    )
def _compress_image_to_b64(image_path, scale=6, quality=60):
    """Return ``image_path`` as a shrunken greyscale JPEG, base64-encoded to str.

    Each side is divided by ``scale`` and clamped to at least 1 px. Without
    the clamp, images smaller than ``scale`` px would make ``resize`` fail on
    a zero-size target. The result is converted to mode "L" and saved as JPEG
    at ``quality``.
    """
    with PILImage.open(image_path) as img:
        img = img.resize((max(1, img.width // scale), max(1, img.height // scale)))
        img = img.convert("L")
        with io.BytesIO() as output:
            img.save(output, format="JPEG", quality=quality)
            compressed_image_data = output.getvalue()
    return base64.b64encode(compressed_image_data).decode('utf-8')


def format_prompt_inputs(image_collection, text_collection, video_collection, user_query,
                         frame_max_distance=1.55, image_max_distance=1.5, text_max_distance=1.3):
    """Collect the retrieval context used to answer ``user_query``.

    Args:
        image_collection, text_collection, video_collection: Collections
            queried with ``user_query``.
        user_query (str): The user's question.
        frame_max_distance, image_max_distance, text_max_distance (float):
            Distance cut-offs for frames, images and texts. The defaults are
            the previously hard-coded values.

    Returns:
        dict: ``"query"`` (the query), ``"texts"`` (matching documents),
        ``"frame"`` (best frame URI or ``""``) and ``"image_data_1"``
        (base64 JPEG of the best image or ``""``).
    """
    frame_candidates = frame_uris(video_collection, user_query, max_distance=frame_max_distance)
    image_candidates = image_uris(image_collection, user_query, max_distance=image_max_distance)
    texts = text_uris(text_collection, user_query, max_distance=text_max_distance)

    return {
        "query": user_query,
        "texts": texts,
        "frame": frame_candidates[0] if frame_candidates else "",
        "image_data_1": _compress_image_to_b64(image_candidates[0]) if image_candidates else "",
    }
def page_1():
    """
    Render the "Upload and Process PDF" page.

    Saves the uploaded PDF into the system temp directory and runs
    ``process_pdf`` on it. On success, the returned image/text/video
    collections are stored in ``st.session_state`` for page 2. Processing
    errors are shown with ``st.error`` instead of being raised.
    """
    import tempfile  # only this page needs a scratch directory

    st.title("Page 1: Upload and Process PDF")
    uploaded_file = st.file_uploader("Upload a PDF file", type=["pdf"])
    if not uploaded_file:
        return

    # The file name is supplied by the browser. Keep only its last component
    # so a crafted name such as "../../x.pdf" cannot write outside the temp dir.
    file_name = os.path.basename(uploaded_file.name) or "upload.pdf"
    pdf_path = os.path.join(tempfile.gettempdir(), file_name)
    with open(pdf_path, "wb") as f:
        f.write(uploaded_file.getbuffer())

    try:
        image_collection, text_collection, video_collection = process_pdf(pdf_path)
    except Exception as e:
        st.error(f"Error processing PDF: {e}")
    else:
        st.session_state.image_collection = image_collection
        st.session_state.text_collection = text_collection
        st.session_state.video_collection = video_collection
        st.success("PDF processed successfully! Collections saved to session state.")
def _render_related_video(frame):
    """Play the stock video a retrieved frame belongs to, or say none was found.

    The original derived the name with ``os.path.basename(frame).split('/')[0]``.
    The split is a no-op after ``basename``, so the lookup was always
    ``<frame file>.mp4`` (e.g. ``frame_001.jpg.mp4``). That name is still
    tried first, followed by the frame's parent folder name.
    NOTE(review): the parent-folder fallback assumes one frame folder per
    video, as suggested by ``frame_uris`` de-duplicating by folder. Confirm
    against the frame extractor.
    """
    if frame:
        candidates = (
            os.path.basename(frame),
            os.path.basename(os.path.dirname(frame)),
        )
        for name in candidates:
            video_path = f"StockVideos-CC0/{name}.mp4"
            if name and os.path.exists(video_path):
                st.video(video_path)
                return
    st.write("No related videos found.")


def page_2():
    """
    Render the "Query and Use Processed Collections" page.

    Requires the three collections stored in ``st.session_state`` by page 1.
    For the entered query it gathers retrieval context and shows the answer
    from ``intiate_convo``, then the matching images and a related video.
    """
    st.title("Page 2: Query and Use Processed Collections")
    required = ("image_collection", "text_collection", "video_collection")
    if not all(key in st.session_state for key in required):
        st.error("Collections not found in session state. Please process the PDF on Page 1.")
        return

    image_collection = st.session_state.image_collection
    text_collection = st.session_state.text_collection
    video_collection = st.session_state.video_collection
    st.success("Collections loaded successfully.")

    query = st.text_input("Enter your query", value="Example Query")
    if not query:
        return

    inputs = format_prompt_inputs(image_collection, text_collection, video_collection, query)
    texts = inputs["texts"]
    image_data_1 = inputs["image_data_1"]
    if image_data_1:
        # Presumably turns the base64 image into a text description for the
        # LLM prompt (inferred from the helper's name; verify in utils.doi).
        image_data_1 = process_image_and_get_description(image_data_1)
    response = intiate_convo(query, image_data_1, texts)
    st.write("Response:", response)

    st.markdown("### Images")
    display_images(image_collection, query, max_distance=1.55, debug=True)

    st.markdown("### Videos")
    _render_related_video(inputs["frame"])
# --- Navigation ---
# Sidebar title -> render function. Dict insertion order is the order in
# which the titles appear in the selectbox.
PAGES = {
    "Upload and Process PDF": page_1,
    "Query and Use Processed Collections": page_2,
}

# On every script run, render whichever page title is selected in the sidebar.
selected_page = st.sidebar.selectbox("Choose a page", options=[*PAGES])
PAGES[selected_page]()