Titan / app.py
NEXAS's picture
Update app.py
ff25b71 verified
raw
history blame
12.5 kB
import os
import zipfile
import time
import streamlit as st
from PIL import Image as PILImage
from PIL import Image as pilImage
import base64
import io
import chromadb
from initate import process_pdf
from utils.llm_ag import intiate_convo
from utils.doi import process_image_and_get_description
# Directory handed to the Chroma persistent client created below.
path = "mm_vdb2"
# Created at import time as a module-level side effect.
# NOTE(review): `client` is not referenced by any function visible in this
# file; confirm whether another module relies on it before removing.
client = chromadb.PersistentClient(path=path)
import streamlit as st
from PIL import Image as PILImage
def display_images(image_collection, query_text, max_distance=None, debug=False, max_results=10):
    """
    Display images in a Streamlit app based on a query.

    Args:
        image_collection: The image collection object for querying. It must
            expose ``query(query_texts=..., n_results=..., include=...)``.
        query_text (str): The text query for images.
        max_distance (float, optional): Maximum allowable distance for
            filtering. ``None`` disables filtering.
        debug (bool, optional): Whether to print debug information.
        max_results (int, optional): How many results to request from the
            collection. Defaults to 10, the value that used to be hard-coded.
    """
    results = image_collection.query(
        query_texts=[query_text],
        n_results=max_results,
        include=['uris', 'distances'],
    )
    uris = results['uris'][0]
    distances = results['distances'][0]

    # Images are rendered in ascending URI order, not by distance.
    sorted_results = sorted(zip(uris, distances), key=lambda x: x[0])

    for uri, distance in sorted_results:
        if max_distance is None or distance <= max_distance:
            if debug:
                st.write(f"URI: {uri} - Distance: {distance}")
            try:
                # The context manager releases the file handle once the
                # image has been handed to Streamlit.
                with PILImage.open(uri) as img:
                    st.image(img, width=300)
            except Exception as e:
                # One unreadable image must not abort the remaining results.
                st.error(f"Error loading image {uri}: {e}")
        elif debug:
            st.write(f"URI: {uri} - Distance: {distance} (Filtered out)")
def display_videos_streamlit(video_collection, query_text, max_distance=None, max_results=5, debug=False):
    """
    Render the videos matching a text query in a Streamlit app.

    Each result's video location is read from its metadata under the
    ``'video_uri'`` key, and every distinct video is rendered at most once.

    Args:
        video_collection: The video collection object for querying.
        query_text (str): The text query for videos.
        max_distance (float, optional): Results farther than this are skipped.
            ``None`` disables the distance filter.
        max_results (int, optional): Number of results requested from the collection.
        debug (bool, optional): Whether to write per-result debug lines.
    """
    response = video_collection.query(
        query_texts=[query_text],
        n_results=max_results,
        include=['uris', 'distances', 'metadatas'],
    )
    rows = zip(response['uris'][0], response['distances'][0], response['metadatas'][0])

    # Video URIs already rendered during this call.
    already_shown = set()
    for uri, distance, metadata in rows:
        video_uri = metadata['video_uri']
        in_range = max_distance is None or distance <= max_distance

        # Skip results that are too far away or whose video was already shown.
        if not in_range or video_uri in already_shown:
            if debug:
                st.write(f"URI: {uri} - Video URI: {video_uri} - Distance: {distance} (Filtered out)")
            continue

        if debug:
            st.write(f"URI: {uri} - Video URI: {video_uri} - Distance: {distance}")
        st.video(video_uri)
        already_shown.add(video_uri)
def image_uris(image_collection, query_text, max_distance=None, max_results=5):
    """
    Return the URIs of images matching a text query.

    Args:
        image_collection: The image collection object for querying.
        query_text (str): The text query for images.
        max_distance (float, optional): Keep only results whose distance is at
            most this value. ``None`` keeps every result.
        max_results (int, optional): Number of results requested from the collection.

    Returns:
        list: Matching URIs, in the order the collection returned them.
    """
    hits = image_collection.query(
        query_texts=[query_text],
        n_results=max_results,
        include=['uris', 'distances'],
    )
    scored_uris = zip(hits['uris'][0], hits['distances'][0])
    return [
        uri
        for uri, distance in scored_uris
        if max_distance is None or distance <= max_distance
    ]
def text_uris(text_collection, query_text, max_distance=None, max_results=5):
    """
    Return the documents matching a text query.

    Despite the name, the returned items are the collection's documents,
    not URIs.

    Args:
        text_collection: The text collection object for querying.
        query_text (str): The text query.
        max_distance (float, optional): Keep only results whose distance is at
            most this value. ``None`` keeps every result.
        max_results (int, optional): Number of results requested from the collection.

    Returns:
        list: Matching documents, in the order the collection returned them.
    """
    response = text_collection.query(
        query_texts=[query_text],
        n_results=max_results,
        include=['documents', 'distances'],
    )
    documents = response['documents'][0]
    scores = response['distances'][0]

    kept = []
    for document, score in zip(documents, scores):
        too_far = max_distance is not None and not (score <= max_distance)
        if not too_far:
            kept.append(document)
    return kept
def frame_uris(video_collection, query_text, max_distance=None, max_results=5):
    """
    Return at most one matching frame URI per folder.

    Args:
        video_collection: The video (frame) collection object for querying.
        query_text (str): The text query.
        max_distance (float, optional): Keep only results whose distance is at
            most this value. ``None`` keeps every result.
        max_results (int, optional): Number of results requested from the
            collection, and the cap on the returned list.

    Returns:
        list: The first in-range URI seen for each distinct parent folder,
        in the order the collection returned them.
    """
    response = video_collection.query(
        query_texts=[query_text],
        n_results=max_results,
        include=['uris', 'distances'],
    )

    # Maps folder -> first accepted frame; dict insertion order keeps result order.
    first_frame_in = {}
    for frame_uri, score in zip(response['uris'][0], response['distances'][0]):
        accepted = max_distance is None or score <= max_distance
        if accepted:
            first_frame_in.setdefault(os.path.dirname(frame_uri), frame_uri)
        if len(first_frame_in) == max_results:
            break
    return list(first_frame_in.values())
def image_uris2(image_collection2, query_text, max_distance=None, max_results=5):
    """
    Return the URIs of images in a second collection matching a text query.

    Args:
        image_collection2: The image collection object for querying.
        query_text (str): The text query for images.
        max_distance (float, optional): Keep only results whose distance is at
            most this value. ``None`` keeps every result.
        max_results (int, optional): Number of results requested from the collection.

    Returns:
        list: Matching URIs, in the order the collection returned them.
    """
    response = image_collection2.query(
        query_texts=[query_text],
        n_results=max_results,
        include=['uris', 'distances'],
    )
    uri_row = response['uris'][0]
    distance_row = response['distances'][0]
    pairs = zip(uri_row, distance_row)

    # Without a threshold every paired result is kept.
    if max_distance is None:
        return [uri for uri, _ in pairs]
    return [uri for uri, distance in pairs if distance <= max_distance]
def format_prompt_inputs(image_collection, text_collection, video_collection, user_query):
    """
    Gather the retrieval results used to build the prompt for a query.

    Args:
        image_collection: Collection queried for images (distance <= 1.5).
        text_collection: Collection queried for documents (distance <= 1.3).
        video_collection: Collection queried for video frames (distance <= 1.55).
        user_query (str): The user's text query.

    Returns:
        dict: With keys
            ``"query"``: the query text, unchanged;
            ``"texts"``: list of matching documents;
            ``"frame"``: first matching frame URI, or ``""`` when none matched;
            ``"image_data_1"``: base64 text of the first matching image,
            down-scaled and re-encoded as JPEG, or ``""`` when none matched.
    """
    frame_candidates = frame_uris(video_collection, user_query, max_distance=1.55)
    image_candidates = image_uris(image_collection, user_query, max_distance=1.5)
    texts = text_uris(text_collection, user_query, max_distance=1.3)

    inputs = {
        "query": user_query,
        "texts": texts,
        "frame": frame_candidates[0] if frame_candidates else "",
        "image_data_1": "",
    }

    if image_candidates:
        with PILImage.open(image_candidates[0]) as source:
            # Halve each dimension, but never drop to 0 px: a zero-sized
            # image cannot be encoded.
            half_size = (max(1, source.width // 2), max(1, source.height // 2))
            reduced = source.resize(half_size)

        # The JPEG encoder rejects modes such as RGBA, LA and P (common for
        # PNGs), so those are converted. Modes it already accepts are left
        # untouched to keep the previous output.
        if reduced.mode not in ("1", "L", "RGB", "RGBX", "CMYK", "YCbCr"):
            reduced = reduced.convert("RGB")

        with io.BytesIO() as output:
            reduced.save(output, format="JPEG", quality=60)
            compressed_image_data = output.getvalue()
        inputs["image_data_1"] = base64.b64encode(compressed_image_data).decode('utf-8')

    return inputs
def unzip_file(zip_path, extract_to):
    """
    Extract every member of a zip archive into a directory.

    The destination directory is created first if it does not exist. Any
    failure is printed and reported through the return value rather than
    raised.

    Args:
        zip_path (str): Path to the zip file.
        extract_to (str): Directory where the contents should be extracted.

    Returns:
        bool: ``True`` when extraction completed, ``False`` when any error occurred.
    """
    try:
        os.makedirs(extract_to, exist_ok=True)
        with zipfile.ZipFile(zip_path, 'r') as archive:
            archive.extractall(extract_to)
    except Exception as e:
        print(f"An error occurred: {e}")
        return False
    else:
        return True
def process_pdf(pdf_path, temp_folder=None):
    """
    Build the image, text and video collections for a PDF.

    This module-level definition rebinds the name ``process_pdf`` that the
    top of the file imports from ``initate``. The previous body was a
    one-argument placeholder that slept and returned three strings, so the
    two-argument call made by the upload page raised ``TypeError``. It now
    forwards to the imported implementation instead.

    Args:
        pdf_path (str): Path to the PDF file on disk.
        temp_folder (str, optional): Folder holding the extracted video
            files. When omitted, only ``pdf_path`` is forwarded.

    Returns:
        Whatever ``initate.process_pdf`` returns. The upload page unpacks it
        as ``(image_collection, text_collection, video_collection)``.
    """
    # Imported under an alias because this function's own name hides the
    # module-level import.
    # NOTE(review): the (pdf_path, temp_folder) signature is inferred from
    # the call site in this file; confirm against initate.process_pdf.
    from initate import process_pdf as _build_collections

    if temp_folder is None:
        return _build_collections(pdf_path)
    return _build_collections(pdf_path, temp_folder)
def page_1():
    """
    Render the upload page: extract video archives and process PDFs.

    Uploaded zip files are written under ``/tmp`` and extracted into
    ``/tmp/extracted_files/<archive name>``. Uploaded PDFs are written under
    ``/tmp`` and passed to ``process_pdf`` together with the extraction
    folder. The collections returned for the last processed PDF are stored in
    ``st.session_state`` as ``image_collection``, ``text_collection`` and
    ``video_collection``. Any exception is reported with ``st.error``.
    """
    st.title("Page 1: Upload and Process Videos and PDFs")
    uploaded_video_zips = st.file_uploader("Upload ZIP files containing videos", type=["zip"], accept_multiple_files=True)
    uploaded_pdf_files = st.file_uploader("Upload PDF files", type=["pdf"], accept_multiple_files=True)

    # The button is only rendered once at least one file has been uploaded.
    if not ((uploaded_video_zips or uploaded_pdf_files) and st.button("Process Files")):
        return

    temp_folder = "/tmp/extracted_files"
    os.makedirs(temp_folder, exist_ok=True)
    progress_bar = st.progress(0)
    status_text = st.empty()

    try:
        total_files = len(uploaded_video_zips) + len(uploaded_pdf_files)
        files_processed = 0

        def report(done, message):
            # st.progress accepts an int in 0..100 or a float in 0.0..1.0.
            # The previous float percentages (e.g. 100.0) were rejected.
            progress_bar.progress(min(100, (100 * done) // max(total_files, 1)))
            status_text.text(message)

        # Process video zip files
        for uploaded_file in uploaded_video_zips:
            # The file name comes from the client; keep only its final
            # component so it cannot escape /tmp.
            safe_name = os.path.basename(uploaded_file.name)
            zip_path = os.path.join("/tmp", safe_name)
            with open(zip_path, "wb") as f:
                f.write(uploaded_file.getbuffer())

            folder_name = os.path.splitext(safe_name)[0]
            extract_to = os.path.join(temp_folder, folder_name)
            if not unzip_file(zip_path, extract_to):
                # Surface the failure instead of silently skipping the archive.
                st.warning(f"Could not extract {uploaded_file.name}; skipping it.")
            files_processed += 1
            report(files_processed, f"Extracting: {uploaded_file.name} ({files_processed}/{total_files})")

        # Process PDF files
        collections = None
        for uploaded_pdf in uploaded_pdf_files:
            pdf_path = os.path.join("/tmp", os.path.basename(uploaded_pdf.name))
            with open(pdf_path, "wb") as f:
                f.write(uploaded_pdf.getbuffer())

            status_text.text(f"Processing PDF: {uploaded_pdf.name} ({files_processed + 1}/{total_files})")
            collections = process_pdf(pdf_path, temp_folder)
            files_processed += 1
            report(files_processed, f"Processing PDF: {uploaded_pdf.name} ({files_processed}/{total_files})")

        if collections is None:
            # Zip-only upload: there are no collections to store. Previously
            # this path hit unbound variables.
            status_text.text("Extraction completed successfully!")
            st.info("Videos extracted. Upload at least one PDF to build the collections.")
            return

        # Save collections to session state
        image_collection, text_collection, video_collection = collections
        st.session_state.image_collection = image_collection
        st.session_state.text_collection = text_collection
        st.session_state.video_collection = video_collection

        status_text.text("Extraction and processing completed successfully!")
        st.success("Videos and PDFs processed successfully! Collections saved to session state.")
    except Exception as e:
        # Page-level boundary: reset the progress widgets and show the error.
        progress_bar.progress(0)
        status_text.text("")
        st.error(f"Error processing files: {e}")
def page_2():
    """
    Render the query page: answer a query from the stored collections.

    Requires ``image_collection``, ``text_collection`` and
    ``video_collection`` in ``st.session_state``; otherwise an error is shown.
    For a non-empty query it gathers prompt inputs, optionally replaces the
    retrieved image data with the result of
    ``process_image_and_get_description``, writes the ``intiate_convo``
    response, and then shows matching images and videos.
    """
    st.title("Page 2: Query and Use Processed Collections")

    required_keys = ("image_collection", "text_collection", "video_collection")
    if not all(key in st.session_state for key in required_keys):
        st.error("Collections not found in session state. Please process the PDF on Page 1.")
        return

    image_collection = st.session_state.image_collection
    text_collection = st.session_state.text_collection
    video_collection = st.session_state.video_collection
    st.success("Collections loaded successfully.")

    query = st.text_input("Enter your query", value="Example Query")
    if not query:
        return

    inputs = format_prompt_inputs(image_collection, text_collection, video_collection, query)
    texts = inputs["texts"]
    image_data_1 = inputs["image_data_1"]
    if image_data_1:
        image_data_1 = process_image_and_get_description(image_data_1)
    response = intiate_convo(query, image_data_1, texts)
    st.write("Response:", response)

    st.markdown("### Images")
    display_images(image_collection, query, max_distance=1.55, debug=True)

    st.markdown("### Videos")
    display_videos_streamlit(video_collection, query_text=query, max_distance=None, max_results=5, debug=False)

    # Try to locate the source video of the best-matching frame.
    frame = inputs["frame"]
    video_path = None
    if frame:
        # The first candidate is the name the previous code built: its
        # ``.split('/')[0]`` on a basename was a no-op, so it is simply the
        # frame's file name.
        # NOTE(review): the second candidate assumes frames live in a folder
        # named after their video, as the per-folder de-duplication in
        # frame_uris suggests. Confirm against the on-disk layout.
        candidate_stems = (
            os.path.basename(frame),
            os.path.basename(os.path.dirname(frame)),
        )
        for stem in candidate_stems:
            candidate = f"video/StockVideos-CC0/{stem}.mp4"
            if stem and os.path.exists(candidate):
                video_path = candidate
                break

    if video_path:
        st.video(video_path)
    else:
        # Shown whenever no video could be rendered, whether no frame
        # matched or the file is missing.
        st.write("No related videos found.")
# --- Navigation ---
# Sidebar label -> function that renders the corresponding page.
PAGES = {
    "Upload and Process PDF": page_1,
    "Query and Use Processed Collections": page_2,
}
# Ask the user which page to show; labels appear in declaration order.
selected_page = st.sidebar.selectbox("Choose a page", options=list(PAGES))
# Look up the chosen page's renderer and run it.
PAGES[selected_page]()