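"""
Streamlit app for querying multimodal (image / text / video) ChromaDB collections
built from an uploaded PDF.

Page 1 uploads a PDF and builds the collections via process_pdf; Page 2 queries
them and passes the retrieved context to the LLM helpers in utils.

To run locally (assuming this file is saved as app.py):
    streamlit run app.py
"""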
import os
import io
import base64
import time

import streamlit as st
import chromadb
from PIL import Image as PILImage

from initate import process_pdf
from utils.llm_ag import intiate_convo
from utils.doi import process_image_and_get_description

# Persistent ChromaDB client for the multimodal vector store
path = "mm_vdb2"
client = chromadb.PersistentClient(path=path)


def display_images(image_collection, query_text, max_distance=None, debug=False):
    """
    Display images in a Streamlit app based on a query.

    Args:
        image_collection: The image collection object for querying.
        query_text (str): The text query for images.
        max_distance (float, optional): Maximum allowable distance for filtering.
        debug (bool, optional): Whether to print debug information.
    """
    results = image_collection.query(
        query_texts=[query_text],
        n_results=10,
        include=['uris', 'distances']
    )

    uris = results['uris'][0]
    distances = results['distances'][0]

    # Combine uris and distances, then sort by URI in ascending order
    sorted_results = sorted(zip(uris, distances), key=lambda x: x[0])

    # Filter and display images
    for uri, distance in sorted_results:
        if max_distance is None or distance <= max_distance:
            if debug:
                st.write(f"URI: {uri} - Distance: {distance}")
            try:
                img = PILImage.open(uri)
                st.image(img, width=300)
            except Exception as e:
                st.error(f"Error loading image {uri}: {e}")
        else:
            if debug:
                st.write(f"URI: {uri} - Distance: {distance} (Filtered out)")

            

def display_videos_streamlit(video_collection, query_text, max_distance=None, max_results=5, debug=False):
    """
    Display videos in a Streamlit app based on a query.

    Args:
        video_collection: The video collection object for querying.
        query_text (str): The text query for videos.
        max_distance (float, optional): Maximum allowable distance for filtering.
        max_results (int, optional): Maximum number of results to display.
        debug (bool, optional): Whether to print debug information.
    """
    # Deduplication set
    displayed_videos = set()

    # Query the video collection with the specified text
    results = video_collection.query(
        query_texts=[query_text],
        n_results=max_results,  # Adjust the number of results if needed
        include=['uris', 'distances', 'metadatas']
    )

    # Extract URIs, distances, and metadatas from the result
    uris = results['uris'][0]
    distances = results['distances'][0]
    metadatas = results['metadatas'][0]

    # Display the videos that meet the distance criteria
    for uri, distance, metadata in zip(uris, distances, metadatas):
        video_uri = metadata['video_uri']

        # Check if a max_distance filter is applied and the distance is within the allowed range
        if (max_distance is None or distance <= max_distance) and video_uri not in displayed_videos:
            if debug:
                st.write(f"URI: {uri} - Video URI: {video_uri} - Distance: {distance}")
            st.video(video_uri)  # Display video in Streamlit
            displayed_videos.add(video_uri)  # Add to the set to prevent duplication
        else:
            if debug:
                st.write(f"URI: {uri} - Video URI: {video_uri} - Distance: {distance} (Filtered out)")
                
                
def image_uris(image_collection, query_text, max_distance=None, max_results=5):
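    """
    Query the image collection with query_text and return the URIs of results
    whose distance is within max_distance (all results if max_distance is None).
    """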
    results = image_collection.query(
        query_texts=[query_text],
        n_results=max_results,
        include=['uris', 'distances']
    )

    filtered_uris = []
    for uri, distance in zip(results['uris'][0], results['distances'][0]):
        if max_distance is None or distance <= max_distance:
            filtered_uris.append(uri)

    return filtered_uris

def text_uris(text_collection, query_text, max_distance=None, max_results=5):
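    """
    Query the text collection with query_text and return the matching document
    chunks whose distance is within max_distance.
    """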
    results = text_collection.query(
        query_texts=[query_text],
        n_results=max_results,
        include=['documents', 'distances']
    )

    filtered_texts = []
    for doc, distance in zip(results['documents'][0], results['distances'][0]):
        if max_distance is None or distance <= max_distance:
            filtered_texts.append(doc)

    return filtered_texts

def frame_uris(video_collection, query_text, max_distance=None, max_results=5):
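    """
    Query the video-frame collection with query_text and return matching frame
    URIs, keeping at most one frame per folder so each video appears only once.
    """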
    results = video_collection.query(
        query_texts=[query_text],
        n_results=max_results,
        include=['uris', 'distances']
    )

    filtered_uris = []
    seen_folders = set()

    for uri, distance in zip(results['uris'][0], results['distances'][0]):
        if max_distance is None or distance <= max_distance:
            folder = os.path.dirname(uri)
            if folder not in seen_folders:
                filtered_uris.append(uri)
                seen_folders.add(folder)

        if len(filtered_uris) == max_results:
            break

    return filtered_uris

def image_uris2(image_collection2, query_text, max_distance=None, max_results=5):
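    """
    Same as image_uris, but runs the query against a second image collection.
    """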
    results = image_collection2.query(
        query_texts=[query_text],
        n_results=max_results,
        include=['uris', 'distances']
    )

    filtered_uris = []
    for uri, distance in zip(results['uris'][0], results['distances'][0]):
        if max_distance is None or distance <= max_distance:
            filtered_uris.append(uri)

    return filtered_uris


def format_prompt_inputs(image_collection, text_collection, video_collection, user_query):
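    """
    Gather the inputs for the LLM prompt: retrieved text chunks, the best-matching
    video frame URI, and the best-matching image (downscaled, converted to
    grayscale and base64-encoded as JPEG) for the given user query.
    """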
    frame_candidates = frame_uris(video_collection, user_query, max_distance=1.55)
    image_candidates = image_uris(image_collection, user_query, max_distance=1.5)
    texts = text_uris(text_collection, user_query, max_distance=1.3)

    inputs = {"query": user_query, "texts": texts}
    frame = frame_candidates[0] if frame_candidates else ""
    inputs["frame"] = frame

    if image_candidates:
        image = image_candidates[0]
        with PILImage.open(image) as img:
            img = img.resize((img.width // 6, img.height // 6))
            img = img.convert("L")
            with io.BytesIO() as output:
                img.save(output, format="JPEG", quality=60)
                compressed_image_data = output.getvalue()

        inputs["image_data_1"] = base64.b64encode(compressed_image_data).decode('utf-8')
    else:
        inputs["image_data_1"] = ""

    return inputs


def page_1():
    st.title("Page 1: Upload and Process PDF")

    uploaded_file = st.file_uploader("Upload a PDF file", type=["pdf"])
    if uploaded_file:
        pdf_path = f"/tmp/{uploaded_file.name}"
        with open(pdf_path, "wb") as f:
            f.write(uploaded_file.getbuffer())

        # Display a spinner while processing
        with st.spinner("Processing PDF... Please wait."):
            try:
                # Simulate processing stages with a delay (this is just an example)
                time.sleep(1)  # Simulate a step in the processing

                # Step 1: Process images, texts, and videos
                st.text("Extracting content from PDF...")
                image_collection, text_collection, video_collection = process_pdf(pdf_path)
                st.session_state.image_collection = image_collection
                st.session_state.text_collection = text_collection
                st.session_state.video_collection = video_collection

                # Simulate a delay for finalizing (if needed)
                time.sleep(1)  # Simulate final step

                st.success("PDF processed successfully! Collections saved to session state.")
            except Exception as e:
                st.error(f"Error processing PDF: {e}")

def page_2():
    st.title("Page 2: Query and Use Processed Collections")

    if "image_collection" in st.session_state and "text_collection" in st.session_state and "video_collection" in st.session_state:
        image_collection = st.session_state.image_collection
        text_collection = st.session_state.text_collection
        video_collection = st.session_state.video_collection
        st.success("Collections loaded successfully.")

        query = st.text_input("Enter your query", value="Example Query")
        if query:
            inputs = format_prompt_inputs(image_collection, text_collection, video_collection, query)
            texts = inputs["texts"]
            image_data_1 = inputs["image_data_1"]

            image_description = ""
            if image_data_1:
                image_description = process_image_and_get_description(image_data_1)

            response = intiate_convo(query, image_description, texts)
            st.write("Response:", response)

            st.markdown("### Images")
            display_images(image_collection, query, max_distance=1.55, debug=True)

            st.markdown("### Videos")
            frame = inputs["frame"]
            if frame:
                video_path = f"StockVideos-CC0/{os.path.basename(frame).split('/')[0]}.mp4"
                if os.path.exists(video_path):
                    st.video(video_path)
                else:
                    st.write("No related videos found.")
    else:
        st.error("Collections not found in session state. Please process the PDF on Page 1.")

# --- Navigation ---

PAGES = {
    "Upload and Process PDF": page_1,
    "Query and Use Processed Collections": page_2
}

# Select page
selected_page = st.sidebar.selectbox("Choose a page", options=list(PAGES.keys()))

# Render selected page
PAGES[selected_page]()