File size: 5,579 Bytes
efc965e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0ce361f
 
efc965e
0ce361f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
from utils.qa import chain
import streamlit as st
from langchain.memory import ConversationBufferWindowMemory
from langchain_community.chat_message_histories import StreamlitChatMessageHistory

# On-disk location of the persistent Chroma database
# (presumably the multimodal index, given the "mm" prefix — confirm).
path = "mm_vdb2"
client = chromadb.PersistentClient(path=path)
# Retrieval collections, opened by name from the persistent client.
image_collection = client.get_collection(name="image")
video_collection = client.get_collection(name='video_collection')


# Chat history backed by Streamlit session state under the key "chat_messages".
memory_storage = StreamlitChatMessageHistory(key="chat_messages")
# Windowed conversation memory (k=3) over memory_storage.
# NOTE(review): `memory` is not referenced in the code visible here — get_answer
# calls chain.invoke(query) without it; confirm whether the chain should use it.
memory = ConversationBufferWindowMemory(memory_key="chat_history", human_prefix="User", chat_memory=memory_storage, k=3)

def get_answer(query):
    """Send *query* through the QA chain and hand back its raw response.

    The response is returned untouched; callers such as ``home`` extract the
    ``'result'`` entry themselves.
    """
    return chain.invoke(query)

def home():
    """Render the Virtual Tutor chat page.

    Draws the title banner, replays the conversation kept in
    ``st.session_state.messages`` (seeded with an assistant greeting on the
    first run) plus any messages held in ``memory_storage``, and then handles
    a new chat input.

    For a new input, the question is answered through ``get_answer`` and the
    ``'result'`` field is shown. Related images and videos are then retrieved
    from ``image_collection`` and ``video_collection``.
    """
    st.header("Welcome")
    st.markdown("""
        <svg width="600" height="100">
            <text x="50%" y="50%" font-family="San serif" font-size="42px" fill="Black" text-anchor="middle" stroke="white"
             stroke-width="0.3" stroke-linejoin="round">Virtual Tutor - CHAT
            </text>
        </svg>
    """, unsafe_allow_html=True)

    if "messages" not in st.session_state:
        st.session_state.messages = [
            {"role": "assistant", "content": "Hi! How may I assist you today?"}
        ]

    st.markdown("""
        <style> 
        .stChatInputContainer > div {
        background-color: #000000;
        }
        </style>
        """, unsafe_allow_html=True)

    # Replay the prior conversation (Streamlit reruns the script on every interaction).
    for message in st.session_state.messages:
        with st.chat_message(message["role"]):
            st.write(message["content"])

    # Stored history is assumed to alternate user / assistant, starting with the user.
    for i, msg in enumerate(memory_storage.messages):
        name = "user" if i % 2 == 0 else "assistant"
        st.chat_message(name).markdown(msg.content)

    if user_input := st.chat_input("User Input"):
        with st.chat_message("user"):
            st.markdown(user_input)
        # Record the question before invoking the chain so it is kept in the
        # history even if answering fails; the final order (user, assistant) is unchanged.
        st.session_state.messages.append({"role": "user", "content": user_input})

        with st.spinner("Generating Response..."):
            with st.chat_message("assistant"):
                response = get_answer(user_input)
                answer = response['result']
                st.markdown(answer)
                st.session_state.messages.append({"role": "assistant", "content": answer})
                # Both helpers take the collection to query as their first
                # argument; passing only the query text raised a TypeError.
                display_images(image_collection, user_input)
                display_videos_streamlit(video_collection, user_input)


def display_images(image_collection, query_text, max_distance=None, debug=False):
    """
    Display images in a Streamlit app based on a query.

    Up to 10 results are retrieved, ordered by URI, optionally filtered by
    distance, and laid out three per row. Images that are filtered out or fail
    to load do not take up a grid slot.

    Args:
        image_collection: The image collection object for querying.
        query_text (str): The text query for images.
        max_distance (float, optional): Maximum allowable distance for filtering.
        debug (bool, optional): Whether to write each result's URI and distance
            (including filtered-out results) to the page.
    """
    results = image_collection.query(
        query_texts=[query_text],
        n_results=10,
        include=['uris', 'distances']
    )

    # A single query text was sent, so each result field is read from slot 0.
    uris = results['uris'][0]
    distances = results['distances'][0]

    # Combine uris and distances, then sort by URI in ascending order
    sorted_results = sorted(zip(uris, distances), key=lambda x: x[0])

    cols = st.columns(3)  # 3 columns; images cycle through them
    # Counts only images actually shown, so filtered or failed results don't leave gaps.
    shown = 0

    for uri, distance in sorted_results:
        if not (max_distance is None or distance <= max_distance):
            if debug:
                st.write(f"URI: {uri} - Distance: {distance} (Filtered out)")
            continue
        if debug:
            st.write(f"URI: {uri} - Distance: {distance}")
        try:
            img = PILImage.open(uri)
            with cols[shown % 3]:
                st.image(img, use_container_width = True)
        except Exception as e:
            # Best effort: report the broken image and keep rendering the rest.
            st.error(f"Error loading image: {e}")
        else:
            shown += 1

def display_videos_streamlit(video_collection, query_text, max_distance=None, max_results=5, debug=False):
    """
    Render the videos matching a text query in the Streamlit page.

    Every hit carries a ``video_uri`` in its metadata. Several hits may point
    at the same video, so each ``video_uri`` is rendered at most once.

    Args:
        video_collection: The video collection object for querying.
        query_text (str): The text query for videos.
        max_distance (float, optional): Maximum allowable distance for filtering.
        max_results (int, optional): Number of hits requested from the query;
            an upper bound on the videos shown.
        debug (bool, optional): Whether to write each hit's URIs and distance
            (including filtered-out hits) to the page.
    """
    hits = video_collection.query(
        query_texts=[query_text],
        n_results=max_results,
        include=['uris', 'distances', 'metadatas'],
    )

    # Only one query text was sent, so every result field is read from slot 0.
    hit_uris = hits['uris'][0]
    hit_distances = hits['distances'][0]
    hit_metas = hits['metadatas'][0]

    already_shown = set()  # video_uri values rendered so far

    for uri, distance, meta in zip(hit_uris, hit_distances, hit_metas):
        video_uri = meta['video_uri']
        close_enough = max_distance is None or distance <= max_distance
        accepted = close_enough and video_uri not in already_shown

        if not accepted:
            # Too far away, or a duplicate of a video already on the page.
            if debug:
                st.write(f"URI: {uri} - Video URI: {video_uri} - Distance: {distance} (Filtered out)")
            continue

        if debug:
            st.write(f"URI: {uri} - Video URI: {video_uri} - Distance: {distance}")
        st.video(video_uri)
        already_shown.add(video_uri)