|
import os |
|
import streamlit as st |
|
from PIL import Image as PILImage |
|
from PIL import Image as pilImage |
|
import base64 |
|
import io |
|
import chromadb |
|
from initate import process_pdf |
|
from utils.llm_ag import intiate_convo |
|
from utils.doi import process_image_and_get_description |
|
|
|
# Directory handed to Chroma's persistent client as its storage location.
path = "mm_vdb2"

# Module-level Chroma client, created when this script is executed.
client = chromadb.PersistentClient(path=path)
|
import streamlit as st |
|
from PIL import Image as PILImage |
|
|
|
def add_background_image(image_path):
    """Use a local image file as the full-page background of the app.

    The file is read, base64-encoded and embedded as a ``data:`` URL inside
    a ``<style>`` block that targets ``.stApp``; the block is injected with
    ``st.markdown(..., unsafe_allow_html=True)``.

    Args:
        image_path: Path of the image file to embed.

    Raises:
        OSError: If the file cannot be opened or read.
    """
    import mimetypes

    # Declare the media type that matches the file instead of always claiming
    # PNG (the pages in this file pass a .jpg). Fall back to the previous
    # "image/png" when the extension is unknown or not an image type.
    mime_type, _ = mimetypes.guess_type(image_path)
    if not mime_type or not mime_type.startswith("image/"):
        mime_type = "image/png"

    with open(image_path, "rb") as image_file:
        base64_image = base64.b64encode(image_file.read()).decode()

    css = f"""
    <style>
    .stApp {{
        background-image: url("data:{mime_type};base64,{base64_image}");
        background-size: cover;
        background-repeat: no-repeat;
        background-attachment: fixed;
    }}
    </style>
    """
    st.markdown(css, unsafe_allow_html=True)
|
|
|
|
|
|
|
|
|
def display_images(image_collection, query_text, max_distance=None, debug=False):
    """Display images matching a text query in a three-column grid.

    Up to 10 results are requested from the collection. They are ordered by
    URI (not by distance) before being rendered, and results farther than
    ``max_distance`` are skipped.

    Args:
        image_collection: Collection object exposing ``query(...)``.
        query_text (str): The text query for images.
        max_distance (float, optional): Maximum allowable distance; ``None``
            disables filtering.
        debug (bool, optional): When true, write each result's URI and
            distance to the page, including results that were filtered out.
    """
    results = image_collection.query(
        query_texts=[query_text],
        n_results=10,
        include=['uris', 'distances'],
    )

    uris = results['uris'][0]
    distances = results['distances'][0]

    # Ordering is by URI, as in the original implementation.
    sorted_results = sorted(zip(uris, distances), key=lambda x: x[0])

    cols = st.columns(3)

    # Count only the images actually rendered so that filtered-out or
    # unreadable results do not leave gaps in the grid.
    shown = 0
    for uri, distance in sorted_results:
        within_range = max_distance is None or distance <= max_distance
        if not within_range:
            if debug:
                st.write(f"URI: {uri} - Distance: {distance} (Filtered out)")
            continue

        if debug:
            st.write(f"URI: {uri} - Distance: {distance}")

        try:
            # The context manager closes the underlying file once Streamlit
            # has consumed the image.
            with PILImage.open(uri) as img, cols[shown % 3]:
                st.image(img, use_container_width=True)
        except Exception as e:
            # UI boundary: report the failure and keep rendering the rest.
            st.error(f"Error loading image: {e}")
        else:
            shown += 1
|
|
|
def display_videos_streamlit(video_collection, query_text, max_distance=None, max_results=5, debug=False):
    """Render the videos associated with the best-matching query results.

    Each result's metadata supplies a ``video_uri``; every distinct video is
    played at most once, and results beyond ``max_distance`` are skipped.

    Args:
        video_collection: Collection object exposing ``query(...)``.
        query_text (str): The text query for videos.
        max_distance (float, optional): Maximum allowable distance; ``None``
            disables filtering.
        max_results (int, optional): Number of results requested from the
            collection.
        debug (bool, optional): When true, write a line per result stating
            whether it was shown or filtered out.
    """
    already_shown = set()

    response = video_collection.query(
        query_texts=[query_text],
        n_results=max_results,
        include=['uris', 'distances', 'metadatas'],
    )

    hits = zip(
        response['uris'][0],
        response['distances'][0],
        response['metadatas'][0],
    )

    for frame_uri, dist, meta in hits:
        video_uri = meta['video_uri']
        in_range = max_distance is None or dist <= max_distance

        # Skip results that are too distant or whose video is already shown.
        if not in_range or video_uri in already_shown:
            if debug:
                st.write(f"URI: {frame_uri} - Video URI: {video_uri} - Distance: {dist} (Filtered out)")
            continue

        if debug:
            st.write(f"URI: {frame_uri} - Video URI: {video_uri} - Distance: {dist}")
        st.video(video_uri)
        already_shown.add(video_uri)
|
|
|
|
|
def image_uris(image_collection,query_text, max_distance=None, max_results=5):
    """Return the URIs of images matching ``query_text``.

    Args:
        image_collection: Collection object exposing ``query(...)``.
        query_text: Text to search for.
        max_distance: Keep only results whose distance is at most this
            value; ``None`` keeps every result.
        max_results: Number of results requested from the collection.

    Returns:
        list: Matching URIs, in the order the collection returned them.
    """
    hits = image_collection.query(
        query_texts=[query_text],
        n_results=max_results,
        include=['uris', 'distances'],
    )
    candidates = zip(hits['uris'][0], hits['distances'][0])

    if max_distance is None:
        return [uri for uri, _ in candidates]
    return [uri for uri, dist in candidates if dist <= max_distance]
|
|
|
def text_uris(text_collection,query_text, max_distance=None, max_results=5):
    """Return the documents matching ``query_text``.

    Args:
        text_collection: Collection object exposing ``query(...)``.
        query_text: Text to search for.
        max_distance: Keep only results whose distance is at most this
            value; ``None`` keeps every result.
        max_results: Number of results requested from the collection.

    Returns:
        list: Matching documents, in the order the collection returned them.
    """
    response = text_collection.query(
        query_texts=[query_text],
        n_results=max_results,
        include=['documents', 'distances'],
    )
    documents = response['documents'][0]
    scores = response['distances'][0]

    return [
        document
        for document, score in zip(documents, scores)
        if max_distance is None or score <= max_distance
    ]
|
|
|
def frame_uris(video_collection,query_text, max_distance=None, max_results=5):
    """Return one matching frame URI per containing folder.

    Results are scanned in the order the collection returns them; the first
    acceptable frame of each folder (``os.path.dirname`` of its URI) is kept.

    Args:
        video_collection: Collection object exposing ``query(...)``.
        query_text: Text to search for.
        max_distance: Keep only results whose distance is at most this
            value; ``None`` keeps every result.
        max_results: Number of results requested, and the maximum number of
            URIs returned.

    Returns:
        list: Frame URIs, at most one per folder.
    """
    response = video_collection.query(
        query_texts=[query_text],
        n_results=max_results,
        include=['uris', 'distances'],
    )

    # folder -> first accepted URI; dict insertion order keeps result order.
    first_per_folder = {}
    for frame_uri, dist in zip(response['uris'][0], response['distances'][0]):
        if max_distance is not None and not dist <= max_distance:
            continue
        first_per_folder.setdefault(os.path.dirname(frame_uri), frame_uri)
        if len(first_per_folder) == max_results:
            break

    return list(first_per_folder.values())
|
|
|
def image_uris2(image_collection2,query_text, max_distance=None, max_results=5):
    """Return the URIs in ``image_collection2`` that match ``query_text``.

    Args:
        image_collection2: Collection object exposing ``query(...)``.
        query_text: Text to search for.
        max_distance: Keep only results whose distance is at most this
            value; ``None`` keeps every result.
        max_results: Number of results requested from the collection.

    Returns:
        list: Matching URIs, in the order the collection returned them.
    """
    found = image_collection2.query(
        query_texts=[query_text],
        n_results=max_results,
        include=['uris', 'distances'],
    )

    def _close_enough(pair):
        # pair is (uri, distance)
        return max_distance is None or pair[1] <= max_distance

    scored = zip(found['uris'][0], found['distances'][0])
    return [uri for uri, _ in filter(_close_enough, scored)]
|
|
|
|
|
def format_prompt_inputs(image_collection, text_collection, video_collection, user_query):
    """Collect the retrieval results used to build a prompt for a query.

    Args:
        image_collection: Collection queried through ``image_uris``.
        text_collection: Collection queried through ``text_uris``.
        video_collection: Collection queried through ``frame_uris``.
        user_query: The user's query text.

    Returns:
        dict: With keys
            ``"query"`` - the query text,
            ``"texts"`` - list of matching documents,
            ``"frame"`` - URI of the first matching frame, or ``""``,
            ``"image_data_1"`` - base64 text of a downscaled grayscale JPEG
            of the first matching image, or ``""`` when no image matched.
    """
    frame_candidates = frame_uris(video_collection, user_query, max_distance=1.55)
    image_candidates = image_uris(image_collection, user_query, max_distance=1.5)
    texts = text_uris(text_collection, user_query, max_distance=1.3)

    inputs = {
        "query": user_query,
        "texts": texts,
        "frame": frame_candidates[0] if frame_candidates else "",
        "image_data_1": "",
    }

    if image_candidates:
        with PILImage.open(image_candidates[0]) as img:
            # Shrink to one sixth per side. Clamp to 1px so that images
            # smaller than 6px do not produce a zero-sized target, which
            # Pillow's resize rejects.
            target_size = (max(1, img.width // 6), max(1, img.height // 6))
            thumbnail = img.resize(target_size).convert("L")

        with io.BytesIO() as output:
            thumbnail.save(output, format="JPEG", quality=60)
            compressed_image_data = output.getvalue()

        inputs["image_data_1"] = base64.b64encode(compressed_image_data).decode('utf-8')

    return inputs
|
|
|
import time |
|
|
|
def page_1():
    """Render the admin upload page.

    Shows a PDF uploader. When a file is supplied it is written under
    ``/tmp``, passed to ``process_pdf``, and the three returned collections
    are stored in ``st.session_state`` as ``image_collection``,
    ``text_collection`` and ``video_collection``. Processing errors are
    reported on the page rather than raised.
    """
    add_background_image("bg3.jpg")

    st.markdown("""
    <svg width="600" height="100">
    <text x="50%" y="50%" font-family="San serif" font-size="42px" fill="Black" text-anchor="middle" stroke="white"
    stroke-width="0.3" stroke-linejoin="round">ADMIN - UPLOAD
    </text>
    </svg>
    """, unsafe_allow_html=True)

    uploaded_file = st.file_uploader("Upload a PDF file", type=["pdf"])
    if not uploaded_file:
        return

    # The file name comes from the client: keep only its final path component
    # so that a name such as "../../x" cannot write outside /tmp.
    safe_name = os.path.basename(uploaded_file.name) or "upload.pdf"
    pdf_path = f"/tmp/{safe_name}"
    with open(pdf_path, "wb") as f:
        f.write(uploaded_file.getbuffer())

    with st.spinner("Processing PDF... Please wait."):
        try:
            # NOTE(review): fixed one-second pauses kept from the original;
            # their purpose is not evident from this file.
            time.sleep(1)

            st.text("Extracting content from PDF...")
            image_collection, text_collection, video_collection = process_pdf(pdf_path)
            st.session_state.image_collection = image_collection
            st.session_state.text_collection = text_collection
            st.session_state.video_collection = video_collection

            time.sleep(1)

            st.success("PDF processed successfully! Collections saved to session state.")
        except Exception as e:
            # UI boundary: surface any processing failure to the user.
            st.error(f"Error processing PDF: {e}")
|
|
|
def page_2():
    """Render the query page.

    Requires the three collections that the upload page stores in
    ``st.session_state``. For a non-empty query it gathers prompt inputs,
    obtains an image description when an image matched, asks
    ``intiate_convo`` for a response, and then shows the response, the
    matching images and a related video (or a fallback message).
    """
    add_background_image("bg3.jpg")

    st.markdown("""
    <div style="text-align: left;">
    <svg width="600" height="100">
    <text x="0" y="50%" font-family="San serif" font-size="42px" fill="Black" stroke="white"
    stroke-width="0.1" stroke-linejoin="round">Poss Assistant
    </text>
    </svg>
    </div>
    """, unsafe_allow_html=True)

    required_keys = ("image_collection", "text_collection", "video_collection")
    if not all(key in st.session_state for key in required_keys):
        st.error("Collections not found in session state. Please process the PDF on Page 1.")
        return

    image_collection = st.session_state.image_collection
    text_collection = st.session_state.text_collection
    video_collection = st.session_state.video_collection
    st.success("Collections loaded successfully.")

    query = st.text_input("Enter your query", value="Example Query")
    if not query:
        return

    inputs = format_prompt_inputs(image_collection, text_collection, video_collection, query)
    texts = inputs["texts"]
    image_data_1 = inputs["image_data_1"]

    # Replace the encoded image with its description when an image matched.
    if image_data_1:
        image_data_1 = process_image_and_get_description(image_data_1)

    response = intiate_convo(query, image_data_1, texts)

    st.markdown("### Assistant's Response")
    st.markdown(response)

    st.markdown("### Images")
    display_images(image_collection, query, max_distance=1.55, debug=False)

    st.markdown("### Videos")
    frame = inputs["frame"]

    # The second '/'-separated component of the frame path names the video
    # file under videos_flattened/. Paths without that component yield no
    # video instead of raising IndexError.
    video_path = None
    if frame:
        parts = frame.split('/')
        if len(parts) > 1:
            directory_name = parts[1]
            video_path = f"videos_flattened/{directory_name}.mp4"

    # Show the fallback message whenever no video is displayed, including
    # when a frame matched but its video file is missing on disk.
    if video_path and os.path.exists(video_path):
        st.video(video_path)
    else:
        st.write("No related videos found.")
|
|
|
|
|
|
|
# Sidebar label -> function that renders the corresponding page.
PAGES = {
    "Upload and Process PDF": page_1,
    "Query and Use Processed Collections": page_2,
}

# Let the user choose a page in the sidebar, then render that page.
selected_page = st.sidebar.selectbox("Choose a page", options=list(PAGES))

PAGES[selected_page]()