File size: 9,924 Bytes
ae74f7a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
03d82bf
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
import os
import streamlit as st
from PIL import Image as PILImage
from PIL import Image as pilImage
import base64
import io
import chromadb
from initate import process_pdf
from utils.llm_ag import intiate_convo
from utils.doi import process_image_and_get_description

# Directory handed to Chroma's persistent client as its storage path.
path = "mm_vdb2"
# Module-level Chroma client, created when this script is executed.
# NOTE(review): `client` is not referenced by any function visible in this
# file — confirm whether it is still needed.
client = chromadb.PersistentClient(path=path)
import streamlit as st
from PIL import Image as PILImage

def display_images(image_collection, query_text, max_distance=None, debug=False, max_results=10):
    """
    Display images in a Streamlit app based on a query.

    Args:
        image_collection: The image collection object for querying. It must
            expose ``query(query_texts=..., n_results=..., include=...)`` and
            return a mapping with ``'uris'`` and ``'distances'`` entries.
        query_text (str): The text query for images.
        max_distance (float, optional): Maximum allowable distance for filtering.
            ``None`` disables filtering.
        debug (bool, optional): Whether to print debug information.
        max_results (int, optional): Number of results requested from the
            collection. Defaults to 10, the value that used to be hard-coded.

    Returns:
        None. Images, debug lines and load errors are written to the page.
    """
    results = image_collection.query(
        query_texts=[query_text],
        n_results=max_results,
        include=['uris', 'distances'],
    )

    # A single query text was sent, so only the first result list is used.
    uris = results['uris'][0]
    distances = results['distances'][0]

    # Images are shown in ascending URI order, not in order of relevance.
    for uri, distance in sorted(zip(uris, distances), key=lambda pair: pair[0]):
        within_range = max_distance is None or distance <= max_distance
        if not within_range:
            if debug:
                st.write(f"URI: {uri} - Distance: {distance} (Filtered out)")
            continue

        if debug:
            st.write(f"URI: {uri} - Distance: {distance}")
        try:
            # The context manager closes the underlying file as soon as the
            # image has been handed to Streamlit, so handles are not leaked.
            with PILImage.open(uri) as img:
                st.image(img, width=300)
        except Exception as e:
            # Best effort: one unreadable image must not abort the whole page.
            st.error(f"Error loading image {uri}: {e}")

            

def display_videos_streamlit(video_collection, query_text, max_distance=None, max_results=5, debug=False):
    """
    Render the videos whose indexed entries match a text query.

    Each result carries a ``'video_uri'`` entry in its metadata; every
    distinct video is shown at most once, even when several results point
    at it.

    Args:
        video_collection: Collection object queried for matching entries.
        query_text (str): Text to search for.
        max_distance (float, optional): Results farther than this are skipped.
            ``None`` disables the distance filter.
        max_results (int, optional): Number of results requested from the collection.
        debug (bool, optional): When true, write one diagnostic line per result.
    """
    results = video_collection.query(
        query_texts=[query_text],
        n_results=max_results,
        include=['uris', 'distances', 'metadatas'],
    )

    # One query text was sent, so the first result list holds everything.
    matches = zip(results['uris'][0], results['distances'][0], results['metadatas'][0])

    # Videos already rendered during this call.
    shown = set()

    for uri, distance, metadata in matches:
        video_uri = metadata['video_uri']

        too_far = not (max_distance is None or distance <= max_distance)
        if too_far or video_uri in shown:
            # Skipped either for distance or because the video is a repeat.
            if debug:
                st.write(f"URI: {uri} - Video URI: {video_uri} - Distance: {distance} (Filtered out)")
            continue

        if debug:
            st.write(f"URI: {uri} - Video URI: {video_uri} - Distance: {distance}")
        st.video(video_uri)
        shown.add(video_uri)
                
                
def image_uris(image_collection,query_text, max_distance=None, max_results=5):
    """
    Return the URIs of the images matching a text query.

    Args:
        image_collection: Collection object queried for images.
        query_text (str): Text to search for.
        max_distance (float, optional): Drop results farther than this.
            ``None`` keeps every result.
        max_results (int, optional): Number of results requested from the collection.

    Returns:
        list: URIs in the order the collection returned them.
    """
    hits = image_collection.query(
        query_texts=[query_text],
        n_results=max_results,
        include=['uris', 'distances'],
    )
    # Only one query text was sent, so only the first result list matters.
    scored = zip(hits['uris'][0], hits['distances'][0])
    return [
        candidate
        for candidate, dist in scored
        if max_distance is None or dist <= max_distance
    ]

def text_uris(text_collection,query_text, max_distance=None, max_results=5):
    """
    Return the documents matching a text query.

    Despite the name, the returned items are the stored documents, not URIs.

    Args:
        text_collection: Collection object queried for documents.
        query_text (str): Text to search for.
        max_distance (float, optional): Drop results farther than this.
            ``None`` keeps every result.
        max_results (int, optional): Number of results requested from the collection.

    Returns:
        list: Documents in the order the collection returned them.
    """
    hits = text_collection.query(
        query_texts=[query_text],
        n_results=max_results,
        include=['documents', 'distances'],
    )
    # Only one query text was sent, so only the first result list matters.
    scored = zip(hits['documents'][0], hits['distances'][0])
    return [
        passage
        for passage, dist in scored
        if max_distance is None or dist <= max_distance
    ]

def frame_uris(video_collection,query_text, max_distance=None, max_results=5):
    """
    Return matching frame URIs, keeping at most one frame per folder.

    Args:
        video_collection: Collection object queried for frames.
        query_text (str): Text to search for.
        max_distance (float, optional): Drop results farther than this.
            ``None`` keeps every result.
        max_results (int, optional): Number of results requested from the
            collection, and the size at which collecting stops.

    Returns:
        list: The first accepted URI of each distinct parent folder, in the
        order the collection returned them.
    """
    response = video_collection.query(
        query_texts=[query_text],
        n_results=max_results,
        include=['uris', 'distances'],
    )

    kept = []
    visited_dirs = set()

    for candidate, dist in zip(response['uris'][0], response['distances'][0]):
        close_enough = max_distance is None or dist <= max_distance
        if close_enough:
            parent = os.path.dirname(candidate)
            # Only the first frame seen for a folder is kept.
            if parent not in visited_dirs:
                visited_dirs.add(parent)
                kept.append(candidate)

        # Checked after every result, whether or not it was kept.
        if len(kept) == max_results:
            break

    return kept

def image_uris2(image_collection2,query_text, max_distance=None, max_results=5):
    """
    Return the URIs of the entries in ``image_collection2`` matching a text query.

    Args:
        image_collection2: Collection object queried for images.
        query_text (str): Text to search for.
        max_distance (float, optional): Drop results farther than this.
            ``None`` keeps every result.
        max_results (int, optional): Number of results requested from the collection.

    Returns:
        list: URIs in the order the collection returned them.
    """
    def is_close_enough(distance):
        # With no threshold every result is accepted.
        return max_distance is None or distance <= max_distance

    outcome = image_collection2.query(
        query_texts=[query_text],
        n_results=max_results,
        include=['uris', 'distances'],
    )
    # Only one query text was sent, so only the first result list matters.
    uris, distances = outcome['uris'][0], outcome['distances'][0]
    return [uri for uri, distance in zip(uris, distances) if is_close_enough(distance)]


def format_prompt_inputs(image_collection, text_collection, video_collection, user_query):
    """
    Collect the retrieval context for a user query.

    Queries the video, image and text collections with fixed distance
    thresholds (1.55, 1.5 and 1.3 respectively) and packages the best hits.

    Args:
        image_collection: Collection queried through ``image_uris``.
        text_collection: Collection queried through ``text_uris``.
        video_collection: Collection queried through ``frame_uris``.
        user_query (str): The query text.

    Returns:
        dict: With the keys
            ``"query"``: the query text as given;
            ``"texts"``: list of matching documents;
            ``"frame"``: first matching frame URI, or ``""`` if none;
            ``"image_data_1"``: base64 string of a downscaled grayscale JPEG
            of the first matching image, or ``""`` if none.

    Raises:
        Whatever ``PILImage.open`` / ``save`` raise if the best image cannot
        be read or encoded; those errors are not caught here.
    """
    frame_candidates = frame_uris(video_collection, user_query, max_distance=1.55)
    image_candidates = image_uris(image_collection, user_query, max_distance=1.5)
    texts = text_uris(text_collection, user_query, max_distance=1.3)

    inputs = {
        "query": user_query,
        "texts": texts,
        "frame": frame_candidates[0] if frame_candidates else "",
        "image_data_1": "",
    }

    if image_candidates:
        with PILImage.open(image_candidates[0]) as img:
            # Shrink to 1/6 per side to keep the payload small. Clamp to 1 px:
            # for images under 6 px the floor division gives 0, which
            # resize() rejects.
            target_size = (max(1, img.width // 6), max(1, img.height // 6))
            reduced = img.resize(target_size).convert("L")
            with io.BytesIO() as output:
                reduced.save(output, format="JPEG", quality=60)
                compressed_image_data = output.getvalue()

        inputs["image_data_1"] = base64.b64encode(compressed_image_data).decode('utf-8')

    return inputs
    
def page_1():
    """
    Render page 1: upload a PDF, process it, and keep the results.

    The uploaded file is written under ``/tmp`` and passed to
    ``process_pdf``. The three collections it returns are stored in
    ``st.session_state`` as ``image_collection``, ``text_collection`` and
    ``video_collection``. Any exception raised during processing is shown on
    the page instead of propagating.
    """
    # Function-scope import: stdlib `time` is not among the imports visible
    # at the top of this file.
    import time

    st.title("Page 1: Upload and Process PDF")

    # File uploader for PDF
    uploaded_file = st.file_uploader("Upload a PDF file", type=["pdf"])

    # Nothing to do until a file is present and the button is pressed.
    if not (uploaded_file and st.button("Process PDF")):
        return

    # The file name comes from the client; keep only its final component so
    # it cannot point outside /tmp (e.g. "../../etc/x.pdf").
    safe_name = os.path.basename(uploaded_file.name)
    pdf_path = f"/tmp/{safe_name}"
    with open(pdf_path, "wb") as f:
        f.write(uploaded_file.getbuffer())

    # Progress bar
    progress_bar = st.progress(0)
    status_text = st.empty()

    try:
        progress_bar.progress(10)
        status_text.text("Initializing processing...")

        # Simulated progress (10%, 40%, 70%) shown before the real work.
        for progress in range(10, 100, 30):
            # `st.time` is not a Streamlit API; use the stdlib sleep.
            time.sleep(0.5)
            progress_bar.progress(progress)
            status_text.text(f"Processing... {progress}%")

        # Process the PDF and save collections to session state
        image_collection, text_collection, video_collection = process_pdf(pdf_path)
        st.session_state.image_collection = image_collection
        st.session_state.text_collection = text_collection
        st.session_state.video_collection = video_collection

        progress_bar.progress(100)
        status_text.text("Processing completed successfully!")
        st.success("PDF processed successfully! Collections saved to session state.")
    except Exception as e:
        # Page-level boundary: reset the progress widgets and report the error.
        progress_bar.progress(0)
        status_text.text("")
        st.error(f"Error processing PDF: {e}")


def page_2():
    """
    Render page 2: query the processed collections and show the results.

    Requires ``image_collection``, ``text_collection`` and
    ``video_collection`` in ``st.session_state``; otherwise an error is shown
    and nothing else is rendered. For a non-empty query it builds the prompt
    inputs, passes them to ``intiate_convo``, writes the response, then shows
    matching images and a related video. When no video can be shown, a
    "No related videos found." message is written instead.
    """
    st.title("Page 2: Query and Use Processed Collections")

    required_keys = ("image_collection", "text_collection", "video_collection")
    if not all(key in st.session_state for key in required_keys):
        st.error("Collections not found in session state. Please process the PDF on Page 1.")
        return

    image_collection = st.session_state.image_collection
    text_collection = st.session_state.text_collection
    video_collection = st.session_state.video_collection
    st.success("Collections loaded successfully.")

    query = st.text_input("Enter your query", value="Example Query")
    if not query:
        return

    inputs = format_prompt_inputs(image_collection, text_collection, video_collection, query)
    texts = inputs["texts"]
    image_data_1 = inputs["image_data_1"]

    if image_data_1:
        # The base64 image is replaced by the helper's result before being
        # sent on; presumably a text description (per its name) — confirm
        # in utils.doi.
        image_data_1 = process_image_and_get_description(image_data_1)

    response = intiate_convo(query, image_data_1, texts)
    st.write("Response:", response)

    st.markdown("### Images")
    display_images(image_collection, query, max_distance=1.55, debug=True)

    st.markdown("### Videos")
    frame = inputs["frame"]
    video_path = ""
    if frame:
        # NOTE(review): basename() never contains '/', so the split is a
        # no-op on POSIX paths and the video name is the frame's file name
        # (extension included). Confirm whether the frame's parent folder
        # name was intended.
        video_path = f"StockVideos-CC0/{os.path.basename(frame).split('/')[0]}.mp4"

    if video_path and os.path.exists(video_path):
        st.video(video_path)
    else:
        # Covers both "no matching frame" and "video file missing", so the
        # section is never left blank.
        st.write("No related videos found.")

# --- Navigation ---

# Sidebar label -> function that renders the page; the labels are also the
# options offered in the selectbox.
PAGES = {
    "Upload and Process PDF": page_1,
    "Query and Use Processed Collections": page_2,
}

# Ask the user which page to show.
selected_page = st.sidebar.selectbox("Choose a page", options=list(PAGES))

# Look up the chosen page's renderer and run it.
PAGES[selected_page]()