# Virtual Tutor - multimodal Streamlit chat app (text QA + image/video retrieval via ChromaDB).
import chromadb
from PIL import Image as PILImage
import streamlit as st
import os
from utils.qa import chain
from langchain.memory import ConversationBufferWindowMemory
from langchain_community.chat_message_histories import StreamlitChatMessageHistory
import base64
import io
# Initialize Chromadb client
# Persistent on-disk Chroma store rooted at ./mm_vdb2; get_collection requires
# that the collections were already created by an ingestion step elsewhere.
path = "mm_vdb2"
client = chromadb.PersistentClient(path=path)
# NOTE(review): collection naming is inconsistent ("image" vs "video_collection");
# confirm these match the names used at ingestion time.
image_collection = client.get_collection(name="image")
video_collection = client.get_collection(name='video_collection')
# Set up memory storage for the chat
# Streamlit-session-backed message history; the window memory keeps only the
# last k=3 exchanges. NOTE(review): `memory` is never referenced again in this
# file — presumably consumed by utils.qa.chain; verify, otherwise it is dead.
memory_storage = StreamlitChatMessageHistory(key="chat_messages")
memory = ConversationBufferWindowMemory(memory_key="chat_history", human_prefix="User", chat_memory=memory_storage, k=3)
# Function to get an answer from the chain
def get_answer(query):
    """Run *query* through the retrieval-QA chain.

    Returns the chain's "result" field, falling back to a placeholder
    message when the key is absent from the chain's output.
    """
    chain_output = chain.invoke(query)
    return chain_output.get("result", "No result found.")
# Function to display images in the UI
def display_images(image_collection, query_text, max_distance=None, debug=False):
    """Query *image_collection* for *query_text* and render matches in a
    3-column Streamlit grid.

    Args:
        image_collection: Chroma collection whose entries carry image URIs.
        query_text: Free-text query, embedded by the collection.
        max_distance: When given, results with a larger distance are skipped.
        debug: When True, also write each URI and its distance to the page
            (fix: this parameter was previously accepted but ignored, unlike
            the sibling display_videos_streamlit).
    """
    results = image_collection.query(
        query_texts=[query_text],
        n_results=10,
        include=['uris', 'distances'],
    )
    uris = results['uris'][0]
    distances = results['distances'][0]
    # Sort by URI for a stable presentation order (not by relevance).
    sorted_results = sorted(zip(uris, distances), key=lambda x: x[0])
    cols = st.columns(3)
    shown = 0  # count only rendered images so filtered results leave no grid gaps
    for uri, distance in sorted_results:
        if max_distance is not None and distance > max_distance:
            if debug:
                st.write(f"URI: {uri} - Distance: {distance} (Filtered out)")
            continue
        if debug:
            st.write(f"URI: {uri} - Distance: {distance}")
        try:
            img = PILImage.open(uri)
            with cols[shown % 3]:
                st.image(img, use_container_width=True)
            shown += 1
        except Exception as e:
            # Best-effort rendering: a bad/missing file shows an error card
            # instead of aborting the whole grid.
            st.error(f"Error loading image: {e}")
# Function to display videos in the UI
def display_videos_streamlit(video_collection, query_text, max_distance=None, max_results=5, debug=False):
    """Query *video_collection* for *query_text* and play each distinct match.

    Results are frame-level; frames from the same source video are
    deduplicated via their 'video_uri' metadata field. *max_distance*
    filters out weak matches; *debug* writes per-result diagnostics.
    """
    already_played = set()
    results = video_collection.query(
        query_texts=[query_text],
        n_results=max_results,
        include=['uris', 'distances', 'metadatas'],
    )
    hits = zip(results['uris'][0], results['distances'][0], results['metadatas'][0])
    for frame_uri, dist, meta in hits:
        source_video = meta['video_uri']
        close_enough = max_distance is None or dist <= max_distance
        if close_enough and source_video not in already_played:
            if debug:
                st.write(f"URI: {frame_uri} - Video URI: {source_video} - Distance: {dist}")
            st.video(source_video)
            already_played.add(source_video)
        elif debug:
            st.write(f"URI: {frame_uri} - Video URI: {source_video} - Distance: {dist} (Filtered out)")
# Function to format the inputs for image and video processing
def _candidate_uris(collection, query_text, max_distance=None, n_results=5):
    """Return result URIs for *query_text* from *collection*, nearest first,
    dropping any whose distance exceeds *max_distance*."""
    results = collection.query(
        query_texts=[query_text],
        n_results=n_results,
        include=['uris', 'distances'],
    )
    return [
        uri
        for uri, dist in zip(results['uris'][0], results['distances'][0])
        if max_distance is None or dist <= max_distance
    ]

# Function to format the inputs for image and video processing
def format_prompt_inputs(image_collection, video_collection, user_query):
    """Build the input dict for the multimodal prompt.

    Returns a dict with keys:
        "query":        the raw user query,
        "frame":        best-matching video-frame URI, or "" when none pass
                        the distance cutoff,
        "image_data_1": base64-encoded, 1/6-scale grayscale JPEG (quality 60)
                        of the best-matching image, or "" when none qualify.

    Bug fix: the original called frame_uris()/image_uris(), which are neither
    defined nor imported anywhere in this module and would raise NameError at
    runtime; both are replaced by the _candidate_uris helper above, which
    mirrors the query/filter pattern used by the display functions.
    """
    frame_candidates = _candidate_uris(video_collection, user_query, max_distance=1.55)
    image_candidates = _candidate_uris(image_collection, user_query, max_distance=1.5)
    inputs = {"query": user_query}
    inputs["frame"] = frame_candidates[0] if frame_candidates else ""
    if image_candidates:
        image = image_candidates[0]
        # Downscale + grayscale + recompress to keep the base64 payload small.
        with PILImage.open(image) as img:
            img = img.resize((img.width // 6, img.height // 6))
            img = img.convert("L")
            with io.BytesIO() as output:
                img.save(output, format="JPEG", quality=60)
                compressed_image_data = output.getvalue()
        inputs["image_data_1"] = base64.b64encode(compressed_image_data).decode('utf-8')
    else:
        inputs["image_data_1"] = ""
    return inputs
# Main function to initialize and run the UI
# Main function to initialize and run the UI
def home():
    """Render the Virtual Tutor chat page: history, input box, answer,
    and retrieved images/videos for the latest question."""
    # Set up the page layout
    st.set_page_config(layout='wide', page_title="Virtual Tutor")
    # Header
    st.header("Welcome to Virtual Tutor - CHAT")
    # SVG Banner for UI branding
    st.markdown("""
    <svg width="600" height="100">
        <text x="50%" y="50%" font-family="San serif" font-size="42px" fill="Black" text-anchor="middle" stroke="white"
        stroke-width="0.3" stroke-linejoin="round">Virtual Tutor - CHAT
        </text>
    </svg>
    """, unsafe_allow_html=True)
    # Initialize the chat session if not already initialized
    if "messages" not in st.session_state:
        st.session_state.messages = [{"role": "assistant", "content": "Hi! How may I assist you today?"}]
    # Styling for the chat input container
    st.markdown("""
    <style>
        .stChatInputContainer > div {
        background-color: #000000;
        }
    </style>
    """, unsafe_allow_html=True)
    # Display previous chat messages
    for message in st.session_state.messages:
        with st.chat_message(message["role"]):
            st.write(message["content"])
    # Display chat messages from memory
    # NOTE(review): memory_storage holds the same conversation that is also
    # appended to st.session_state.messages below, so history likely renders
    # twice per rerun — confirm against utils.qa.chain's memory wiring.
    # Role is inferred purely from position (even = user, odd = assistant).
    for i, msg in enumerate(memory_storage.messages):
        name = "user" if i % 2 == 0 else "assistant"
        st.chat_message(name).markdown(msg.content)
    # Handle user input and generate response
    if user_input := st.chat_input("Enter your question here..."):
        with st.chat_message("user"):
            st.markdown(user_input)
        with st.spinner("Generating Response..."):
            with st.chat_message("assistant"):
                response = get_answer(user_input)
                answer = response  # alias kept; response is already the final text
                st.markdown(answer)
                # Save user and assistant messages to session state
                message = {"role": "assistant", "content": answer}
                message_u = {"role": "user", "content": user_input}
                st.session_state.messages.append(message_u)
                st.session_state.messages.append(message)
        # Process inputs for image/video
        # Only "frame" is consumed below; "image_data_1" is computed but unused here.
        inputs = format_prompt_inputs(image_collection, video_collection, user_input)
        # Display images
        st.markdown("### Images")
        display_images(image_collection, user_input, max_distance=1.55, debug=False)
        # Display videos based on frames
        st.markdown("### Videos")
        frame = inputs["frame"]
        if frame:
            # Frame URIs are assumed to look like "<root>/<video_name>/<frame>";
            # the second path segment names the flattened source video.
            # NOTE(review): split('/')[1] breaks on Windows separators or a
            # different layout — TODO confirm the frame URI scheme.
            directory_name = frame.split('/')[1]
            video_path = f"videos_flattened/{directory_name}.mp4"
            if os.path.exists(video_path):
                st.video(video_path)
            else:
                st.error("Video file not found.")
# Call the home function to run the app
if __name__ == "__main__":  # standard script-entry guard (streamlit run / python)
    home()