File size: 6,592 Bytes
32bc45c
 
efc965e
32bc45c
 
efc965e
 
32bc45c
 
efc965e
32bc45c
efc965e
 
 
 
 
32bc45c
efc965e
 
 
32bc45c
efc965e
 
32bc45c
efc965e
32bc45c
0ce361f
 
 
 
 
 
 
 
 
 
 
 
32bc45c
0ce361f
 
 
 
 
32bc45c
 
0ce361f
 
 
32bc45c
0ce361f
 
 
 
 
32bc45c
0ce361f
 
 
 
 
 
 
 
 
 
 
 
 
32bc45c
 
0ce361f
 
 
 
32bc45c
1131ca9
 
 
 
 
 
 
 
 
 
 
 
32bc45c
 
1131ca9
32bc45c
1131ca9
 
 
 
 
 
 
 
32bc45c
1131ca9
32bc45c
 
 
 
 
 
 
1131ca9
 
 
 
 
 
 
 
32bc45c
1131ca9
32bc45c
1131ca9
32bc45c
1131ca9
 
 
 
 
 
 
 
32bc45c
 
1131ca9
 
 
32bc45c
1131ca9
 
 
 
32bc45c
 
1131ca9
 
 
 
 
 
32bc45c
1131ca9
32bc45c
 
1131ca9
 
 
 
32bc45c
 
 
 
 
1131ca9
32bc45c
 
 
1131ca9
 
 
32bc45c
1131ca9
 
 
 
32bc45c
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
import chromadb
from PIL import Image as PILImage
import streamlit as st
import os
from utils.qa import chain
from langchain.memory import ConversationBufferWindowMemory
from langchain_community.chat_message_histories import StreamlitChatMessageHistory
import base64
import io

# Initialize Chromadb client
# Module-level side effects: opens a persistent Chroma DB at ./mm_vdb2 and
# fetches two pre-existing collections. get_collection raises if the
# collection was never created (e.g. the ingestion step has not run).
path = "mm_vdb2"
client = chromadb.PersistentClient(path=path)
image_collection = client.get_collection(name="image")
video_collection = client.get_collection(name='video_collection')

# Set up memory storage for the chat
# StreamlitChatMessageHistory persists messages in st.session_state under
# the "chat_messages" key; the window memory keeps only the last k=3 turns.
memory_storage = StreamlitChatMessageHistory(key="chat_messages")
memory = ConversationBufferWindowMemory(memory_key="chat_history", human_prefix="User", chat_memory=memory_storage, k=3)

# Function to get an answer from the chain
def get_answer(query):
    """Run *query* through the QA chain and return its "result" field.

    Falls back to the string "No result found." when the chain's response
    dict carries no "result" key.
    """
    response_payload = chain.invoke(query)
    return response_payload.get("result", "No result found.")

# Function to display images in the UI
def display_images(image_collection, query_text, max_distance=None, debug=False):
    """Query *image_collection* with *query_text* and render matches in a 3-column grid.

    Parameters:
        image_collection: Chroma collection whose documents carry image URIs.
        query_text: free-text query passed to the collection.
        max_distance: if given, results with a larger distance are skipped.
        debug: when True, write each URI/distance (including filtered-out hits)
            to the page — mirrors display_videos_streamlit's debug behavior.
            (Fix: the original accepted this flag but silently ignored it.)
    """
    results = image_collection.query(
        query_texts=[query_text],
        n_results=10,
        include=['uris', 'distances']
    )

    uris = results['uris'][0]
    distances = results['distances'][0]

    # Sort by URI (not distance) so the grid order is stable across reruns.
    sorted_results = sorted(zip(uris, distances), key=lambda x: x[0])

    cols = st.columns(3)

    for i, (uri, distance) in enumerate(sorted_results):
        if max_distance is None or distance <= max_distance:
            if debug:
                st.write(f"URI: {uri} - Distance: {distance}")
            try:
                img = PILImage.open(uri)
                with cols[i % 3]:
                    st.image(img, use_container_width=True)
            except Exception as e:
                # Best-effort: a missing/corrupt file shows an error instead of crashing.
                st.error(f"Error loading image: {e}")
        elif debug:
            st.write(f"URI: {uri} - Distance: {distance} (Filtered out)")

# Function to display videos in the UI
def display_videos_streamlit(video_collection, query_text, max_distance=None, max_results=5, debug=False):
    """Query *video_collection* and embed each distinct matching video once.

    A frame hit maps to its parent video via metadata['video_uri']; duplicate
    videos and hits beyond *max_distance* are skipped. With debug=True every
    hit (shown or filtered) is logged to the page.
    """
    shown = set()

    hits = video_collection.query(
        query_texts=[query_text],
        n_results=max_results,
        include=['uris', 'distances', 'metadatas']
    )

    for uri, distance, metadata in zip(hits['uris'][0], hits['distances'][0], hits['metadatas'][0]):
        video_uri = metadata['video_uri']

        within_cutoff = max_distance is None or distance <= max_distance
        if within_cutoff and video_uri not in shown:
            if debug:
                st.write(f"URI: {uri} - Video URI: {video_uri} - Distance: {distance}")
            st.video(video_uri)
            shown.add(video_uri)
        elif debug:
            st.write(f"URI: {uri} - Video URI: {video_uri} - Distance: {distance} (Filtered out)")

# Function to format the inputs for image and video processing
def format_prompt_inputs(image_collection, video_collection, user_query):
    """Build the prompt-input dict for the chain: query text, best frame URI,
    and a base64-encoded, downscaled grayscale JPEG of the best image match.

    Returns a dict with keys "query", "frame" (URI or ""), and
    "image_data_1" (base64 string or "").
    """
    # NOTE(review): frame_uris and image_uris are not defined or imported in
    # the visible file — as written this raises NameError at call time.
    # Presumably they live in a utils module; confirm and import them.
    frame_candidates = frame_uris(video_collection, user_query, max_distance=1.55)
    image_candidates = image_uris(image_collection, user_query, max_distance=1.5)

    inputs = {"query": user_query}

    # Take only the top-ranked frame candidate; empty string signals "no match".
    frame = frame_candidates[0] if frame_candidates else ""
    inputs["frame"] = frame

    if image_candidates:
        image = image_candidates[0]
        with PILImage.open(image) as img:
            # Shrink to 1/6 size and convert to grayscale ("L") to keep the
            # base64 payload sent to the model small.
            img = img.resize((img.width // 6, img.height // 6))
            img = img.convert("L")
            with io.BytesIO() as output:
                img.save(output, format="JPEG", quality=60)
                compressed_image_data = output.getvalue()

        inputs["image_data_1"] = base64.b64encode(compressed_image_data).decode('utf-8')
    else:
        inputs["image_data_1"] = ""

    return inputs

# Main function to initialize and run the UI
def home():
    """Render the Virtual Tutor chat page: header, chat history, and the
    question → answer → related images/videos flow for each user input."""
    # Set up the page layout
    st.set_page_config(layout='wide', page_title="Virtual Tutor")

    # Header
    st.header("Welcome to Virtual Tutor - CHAT")

    # SVG Banner for UI branding
    st.markdown("""
        <svg width="600" height="100">
            <text x="50%" y="50%" font-family="San serif" font-size="42px" fill="Black" text-anchor="middle" stroke="white"
             stroke-width="0.3" stroke-linejoin="round">Virtual Tutor - CHAT
            </text>
        </svg>
    """, unsafe_allow_html=True)

    # Initialize the chat session if not already initialized
    if "messages" not in st.session_state:
        st.session_state.messages = [{"role": "assistant", "content": "Hi! How may I assist you today?"}]

    # Styling for the chat input container
    st.markdown("""
        <style> 
        .stChatInputContainer > div {
        background-color: #000000;
        }
        </style>
        """, unsafe_allow_html=True)

    # Display previous chat messages
    for message in st.session_state.messages:
        with st.chat_message(message["role"]):
            st.write(message["content"])

    # Display chat messages from memory
    # NOTE(review): this second loop re-renders history from memory_storage on
    # top of the session_state loop above — the same turns may appear twice.
    # It also assumes strict user/assistant alternation (even index = user);
    # confirm whether both renderings are intended.
    for i, msg in enumerate(memory_storage.messages):
        name = "user" if i % 2 == 0 else "assistant"
        st.chat_message(name).markdown(msg.content)

    # Handle user input and generate response
    if user_input := st.chat_input("Enter your question here..."):
        with st.chat_message("user"):
            st.markdown(user_input)

        with st.spinner("Generating Response..."):
            with st.chat_message("assistant"):
                response = get_answer(user_input)
                answer = response
                st.markdown(answer)

                # Save user and assistant messages to session state
                message = {"role": "assistant", "content": answer}
                message_u = {"role": "user", "content": user_input}
                st.session_state.messages.append(message_u)
                st.session_state.messages.append(message)

                # Process inputs for image/video
                # (also populates inputs["frame"], used below to locate the video)
                inputs = format_prompt_inputs(image_collection, video_collection, user_input)
                
                # Display images
                st.markdown("### Images")
                display_images(image_collection, user_input, max_distance=1.55, debug=False)
                
                # Display videos based on frames
                st.markdown("### Videos")
                frame = inputs["frame"]
                if frame:
                    # Frame URIs are assumed to look like "<root>/<video_dir>/<frame>";
                    # segment [1] is the video directory name — TODO confirm layout.
                    directory_name = frame.split('/')[1]
                    video_path = f"videos_flattened/{directory_name}.mp4"
                    if os.path.exists(video_path):
                        st.video(video_path)
                    else:
                        st.error("Video file not found.")

# Script entry point: run the Streamlit page when executed directly.
if __name__ == "__main__":
    home()