# Titan / app.py
# NEXAS's picture
# Update app.py
# 5b6e1ee verified
import os
import streamlit as st
from PIL import Image as PILImage
from PIL import Image as pilImage
import base64
import io
import chromadb
from initate import process_pdf
from utils.llm_ag import intiate_convo
from utils.doi import process_image_and_get_description
# Directory of the on-disk Chroma database opened by the client below.
path = "mm_vdb2"
# Module-level persistent Chroma client backed by `path`. It is not referenced
# anywhere else in this file; the pages use collections returned by process_pdf.
# NOTE(review): Streamlit re-executes the script on every rerun, so this client is
# re-created each time — consider st.cache_resource if it is actually needed.
client = chromadb.PersistentClient(path=path)
import streamlit as st
from PIL import Image as PILImage
def add_background_image(image_path):
    """Set a full-page background image for the Streamlit app.

    The image bytes are base64-encoded and inlined as a ``data:`` URI in a
    ``<style>`` rule that targets the ``.stApp`` container.

    Args:
        image_path (str): Path to the image file to use as the background.

    Raises:
        OSError: If the image file cannot be opened or read.
    """
    import mimetypes

    # Derive the data-URI media type from the file extension; a hard-coded
    # "image/png" mislabels JPEG backgrounds such as "bg3.jpg".
    mime_type, _ = mimetypes.guess_type(image_path)
    if not mime_type or not mime_type.startswith("image/"):
        mime_type = "image/png"  # unknown extension: keep the previous default

    with open(image_path, "rb") as image_file:
        base64_image = base64.b64encode(image_file.read()).decode()

    css = f"""
    <style>
    .stApp {{
    background-image: url("data:{mime_type};base64,{base64_image}");
    background-size: cover;
    background-repeat: no-repeat;
    background-attachment: fixed;
    }}
    </style>
    """
    st.markdown(css, unsafe_allow_html=True)
# add_background_image() is called at the top of page_1() and page_2().
def display_images(image_collection, query_text, max_distance=None, debug=False):
    """
    Display images in a Streamlit app based on a query.

    Requests up to 10 results from the collection, orders them by URI, keeps
    those within ``max_distance`` and lays them out three per row.

    Args:
        image_collection: The image collection object for querying; must expose
            ``query(query_texts=..., n_results=..., include=...)``.
        query_text (str): The text query for images.
        max_distance (float, optional): Maximum allowable distance for filtering.
            ``None`` disables filtering.
        debug (bool, optional): Whether to print debug information (URI and
            distance of every candidate, including filtered-out ones).
    """
    results = image_collection.query(
        query_texts=[query_text],
        n_results=10,
        include=['uris', 'distances']
    )
    uris = results['uris'][0]
    distances = results['distances'][0]

    # Combine uris and distances, then sort by URI in ascending order
    sorted_results = sorted(zip(uris, distances), key=lambda x: x[0])

    # Display images side by side, 3 images per row
    cols = st.columns(3)
    # Count only images actually rendered, so filtered-out or broken entries
    # do not leave gaps in the column rotation.
    shown = 0
    for uri, distance in sorted_results:
        if not (max_distance is None or distance <= max_distance):
            if debug:
                st.write(f"URI: {uri} - Distance: {distance} (Filtered out)")
            continue
        if debug:
            st.write(f"URI: {uri} - Distance: {distance}")
        try:
            # The context manager releases the file handle once the image is rendered.
            with PILImage.open(uri) as img:
                with cols[shown % 3]:
                    st.image(img, use_container_width=True)
        except Exception as e:
            st.error(f"Error loading image: {e}")
            continue
        shown += 1
def display_videos_streamlit(video_collection, query_text, max_distance=None, max_results=5, debug=False):
    """
    Render the videos matching a text query with ``st.video``.

    Each result's ``metadata['video_uri']`` is rendered at most once, and results
    whose distance exceeds ``max_distance`` are skipped.

    Args:
        video_collection: Collection exposing ``query(query_texts=..., n_results=..., include=...)``.
        query_text (str): Text used to query the video collection.
        max_distance (float, optional): Upper bound on distance; ``None`` keeps every result.
        max_results (int, optional): Number of results requested from the collection.
        debug (bool, optional): When true, write each candidate's URI, video URI and distance.
    """
    response = video_collection.query(
        query_texts=[query_text],
        n_results=max_results,
        include=['uris', 'distances', 'metadatas']
    )
    rows = zip(response['uris'][0], response['distances'][0], response['metadatas'][0])

    already_shown = set()
    for frame_uri, dist, meta in rows:
        target = meta['video_uri']
        within_range = max_distance is None or dist <= max_distance
        # Skip anything out of range or already rendered (deduplicate by video URI).
        if not within_range or target in already_shown:
            if debug:
                st.write(f"URI: {frame_uri} - Video URI: {target} - Distance: {dist} (Filtered out)")
            continue
        if debug:
            st.write(f"URI: {frame_uri} - Video URI: {target} - Distance: {dist}")
        st.video(target)
        already_shown.add(target)
def image_uris(image_collection, query_text, max_distance=None, max_results=5):
    """Return the URIs of the images nearest to ``query_text``.

    Args:
        image_collection: Collection exposing ``query(query_texts=..., n_results=..., include=...)``.
        query_text (str): Text used to query the collection.
        max_distance (float, optional): Keep only hits with distance <= this; ``None`` keeps all.
        max_results (int, optional): Number of hits requested from the collection.

    Returns:
        list: Matching URIs, in the order the collection returned them.
    """
    hits = image_collection.query(
        query_texts=[query_text],
        n_results=max_results,
        include=['uris', 'distances'],
    )
    return [
        uri
        for uri, dist in zip(hits['uris'][0], hits['distances'][0])
        if max_distance is None or dist <= max_distance
    ]
def text_uris(text_collection, query_text, max_distance=None, max_results=5):
    """Return the text documents nearest to ``query_text``.

    Args:
        text_collection: Collection exposing ``query(query_texts=..., n_results=..., include=...)``.
        query_text (str): Text used to query the collection.
        max_distance (float, optional): Keep only documents with distance <= this;
            ``None`` keeps all.
        max_results (int, optional): Number of documents requested from the collection.

    Returns:
        list: Matching documents, in the order the collection returned them.
    """
    response = text_collection.query(
        query_texts=[query_text],
        n_results=max_results,
        include=['documents', 'distances'],
    )
    pairs = zip(response['documents'][0], response['distances'][0])
    return [doc for doc, dist in pairs if max_distance is None or dist <= max_distance]
def frame_uris(video_collection, query_text, max_distance=None, max_results=5):
    """Return frame URIs near ``query_text``, at most one per parent folder.

    Args:
        video_collection: Collection exposing ``query(query_texts=..., n_results=..., include=...)``.
        query_text (str): Text used to query the collection.
        max_distance (float, optional): Keep only frames with distance <= this; ``None`` keeps all.
        max_results (int, optional): Number of hits requested, and the cap on returned URIs.

    Returns:
        list: The first qualifying frame URI from each distinct ``os.path.dirname``,
        in the order the collection returned them.
    """
    response = video_collection.query(
        query_texts=[query_text],
        n_results=max_results,
        include=['uris', 'distances'],
    )
    picked = []
    folders_used = set()
    for frame, dist in zip(response['uris'][0], response['distances'][0]):
        if not (max_distance is None or dist <= max_distance):
            continue
        parent = os.path.dirname(frame)
        if parent in folders_used:
            continue  # already have a frame from this folder
        folders_used.add(parent)
        picked.append(frame)
        if len(picked) == max_results:
            break
    return picked
def image_uris2(image_collection2, query_text, max_distance=None, max_results=5):
    """Return the URIs of the images in ``image_collection2`` nearest to ``query_text``.

    Behaves exactly like ``image_uris`` and simply delegates to it, so the two
    filtering implementations cannot drift apart. The separate name is kept so
    any existing callers of ``image_uris2`` keep working.

    Args:
        image_collection2: Collection exposing ``query(query_texts=..., n_results=..., include=...)``.
        query_text (str): Text used to query the collection.
        max_distance (float, optional): Keep only hits with distance <= this; ``None`` keeps all.
        max_results (int, optional): Number of hits requested from the collection.

    Returns:
        list: Matching URIs, in the order the collection returned them.
    """
    return image_uris(
        image_collection2,
        query_text,
        max_distance=max_distance,
        max_results=max_results,
    )
def _compress_image_to_base64(image_path, scale=6, quality=60):
    """Return ``image_path`` shrunk by ``scale``, in grayscale, as a base64 JPEG string."""
    with PILImage.open(image_path) as img:
        # Clamp to at least 1 px: images smaller than ``scale`` px in either
        # dimension would otherwise be resized to 0, which PIL rejects.
        new_size = (max(1, img.width // scale), max(1, img.height // scale))
        small = img.resize(new_size).convert("L")
        with io.BytesIO() as output:
            small.save(output, format="JPEG", quality=quality)
            return base64.b64encode(output.getvalue()).decode('utf-8')


def format_prompt_inputs(image_collection, text_collection, video_collection, user_query):
    """Gather retrieval context for the LLM prompt.

    Args:
        image_collection: Image collection queried via ``image_uris``.
        text_collection: Text collection queried via ``text_uris``.
        video_collection: Video-frame collection queried via ``frame_uris``.
        user_query (str): The user's query text.

    Returns:
        dict: With keys
            ``"query"``: ``user_query`` unchanged;
            ``"texts"``: documents within distance 1.3;
            ``"frame"``: first frame URI within distance 1.55, or ``""``;
            ``"image_data_1"``: first image within distance 1.5, downscaled 6x,
            grayscale, JPEG quality 60, base64-encoded; or ``""``.
    """
    frame_candidates = frame_uris(video_collection, user_query, max_distance=1.55)
    image_candidates = image_uris(image_collection, user_query, max_distance=1.5)
    texts = text_uris(text_collection, user_query, max_distance=1.3)
    return {
        "query": user_query,
        "texts": texts,
        "frame": frame_candidates[0] if frame_candidates else "",
        "image_data_1": (
            _compress_image_to_base64(image_candidates[0]) if image_candidates else ""
        ),
    }
import time # To simulate delays during processing
def page_1():
    """Render the admin upload page.

    Accepts a PDF upload, writes it to the system temp directory, runs
    ``process_pdf`` on it and stores the returned image/text/video collections
    in ``st.session_state`` for use by ``page_2``.
    """
    import tempfile

    add_background_image("bg3.jpg")
    st.markdown("""
    <svg width="600" height="100">
    <text x="50%" y="50%" font-family="sans-serif" font-size="42px" fill="Black" text-anchor="middle" stroke="white"
    stroke-width="0.3" stroke-linejoin="round">ADMIN - UPLOAD
    </text>
    </svg>
    """, unsafe_allow_html=True)

    uploaded_file = st.file_uploader("Upload a PDF file", type=["pdf"])
    if not uploaded_file:
        return

    # The uploaded file name comes from the client and is untrusted: keep only
    # its final path component so a name like "../../x.pdf" cannot write
    # outside the temp directory.
    safe_name = os.path.basename(uploaded_file.name) or "upload.pdf"
    pdf_path = os.path.join(tempfile.gettempdir(), safe_name)
    with open(pdf_path, "wb") as f:
        f.write(uploaded_file.getbuffer())

    # Display a spinner while processing
    with st.spinner("Processing PDF... Please wait."):
        try:
            # Process images, texts, and videos
            st.text("Extracting content from PDF...")
            image_collection, text_collection, video_collection = process_pdf(pdf_path)
        except Exception as e:  # UI boundary: report any processing failure to the user
            st.error(f"Error processing PDF: {e}")
        else:
            st.session_state.image_collection = image_collection
            st.session_state.text_collection = text_collection
            st.session_state.video_collection = video_collection
            st.success("PDF processed successfully! Collections saved to session state.")
def _related_video_path(frame):
    """Return the ``videos_flattened`` MP4 path for a frame URI, or None if it can't be derived."""
    # Takes the second '/'-separated component as the video name.
    # NOTE(review): assumes frame URIs look like "<root>/<video_name>/..." — confirm
    # against how process_pdf lays out frames.
    parts = frame.split('/')
    if len(parts) < 2 or not parts[1]:
        return None
    return f"videos_flattened/{parts[1]}.mp4"


def page_2():
    """Render the query page.

    Requires the collections stored in ``st.session_state`` by ``page_1``. For a
    query it gathers retrieval context, asks the LLM via ``intiate_convo`` and
    shows the response, matching images and a related video if one exists on disk.
    """
    add_background_image("bg3.jpg")
    st.markdown("""
    <div style="text-align: left;">
    <svg width="600" height="100">
    <text x="0" y="50%" font-family="sans-serif" font-size="42px" fill="Black" stroke="white"
    stroke-width="0.1" stroke-linejoin="round">Poss Assistant
    </text>
    </svg>
    </div>
    """, unsafe_allow_html=True)

    required = ("image_collection", "text_collection", "video_collection")
    if not all(key in st.session_state for key in required):
        st.error("Collections not found in session state. Please process the PDF on Page 1.")
        return

    image_collection = st.session_state.image_collection
    text_collection = st.session_state.text_collection
    video_collection = st.session_state.video_collection
    st.success("Collections loaded successfully.")

    query = st.text_input("Enter your query", value="Example Query")
    if not query:
        return

    inputs = format_prompt_inputs(image_collection, text_collection, video_collection, query)
    texts = inputs["texts"]
    image_data_1 = inputs["image_data_1"]
    if image_data_1:
        # intiate_convo receives the result of process_image_and_get_description,
        # not the raw base64 image data.
        image_data_1 = process_image_and_get_description(image_data_1)
    response = intiate_convo(query, image_data_1, texts)

    # Display the response in Markdown format
    st.markdown("### Assistant's Response")
    st.markdown(response)

    st.markdown("### Images")
    display_images(image_collection, query, max_distance=1.55, debug=False)

    st.markdown("### Videos")
    # An empty frame ("") or an unparseable URI yields None -> same message as a missing file.
    video_path = _related_video_path(inputs["frame"])
    if video_path and os.path.exists(video_path):
        st.video(video_path)
    else:
        st.write("No related videos found.")
# --- Navigation ---
# Each sidebar label maps to the function that renders that page.
PAGES = {
    "Upload and Process PDF": page_1,
    "Query and Use Processed Collections": page_2,
}

# Let the user pick a page in the sidebar, then render the chosen one.
selected_page = st.sidebar.selectbox("Choose a page", options=[*PAGES])
PAGES[selected_page]()