NEXAS committed on
Commit
32bc45c
·
verified ·
1 Parent(s): 1131ca9

Update user.py

Browse files
Files changed (1) hide show
  1. user.py +55 -124
user.py CHANGED
@@ -1,31 +1,30 @@
1
- from utils.qa import chain
 
2
  import streamlit as st
 
 
3
  from langchain.memory import ConversationBufferWindowMemory
4
  from langchain_community.chat_message_histories import StreamlitChatMessageHistory
 
 
5
 
 
6
  path = "mm_vdb2"
7
  client = chromadb.PersistentClient(path=path)
8
  image_collection = client.get_collection(name="image")
9
  video_collection = client.get_collection(name='video_collection')
10
 
11
-
12
  memory_storage = StreamlitChatMessageHistory(key="chat_messages")
13
  memory = ConversationBufferWindowMemory(memory_key="chat_history", human_prefix="User", chat_memory=memory_storage, k=3)
14
 
 
15
  def get_answer(query):
16
  response = chain.invoke(query)
17
- #return response["result"]
18
- return response
19
 
 
20
  def display_images(image_collection, query_text, max_distance=None, debug=False):
21
- """
22
- Display images in a Streamlit app based on a query.
23
- Args:
24
- image_collection: The image collection object for querying.
25
- query_text (str): The text query for images.
26
- max_distance (float, optional): Maximum allowable distance for filtering.
27
- debug (bool, optional): Whether to print debug information.
28
- """
29
  results = image_collection.query(
30
  query_texts=[query_text],
31
  n_results=10,
@@ -35,160 +34,79 @@ def display_images(image_collection, query_text, max_distance=None, debug=False)
35
  uris = results['uris'][0]
36
  distances = results['distances'][0]
37
 
38
- # Combine uris and distances, then sort by URI in ascending order
39
  sorted_results = sorted(zip(uris, distances), key=lambda x: x[0])
40
 
41
- # Display images side by side, 3 images per row
42
- cols = st.columns(3) # Create 3 columns for the layout
43
 
44
  for i, (uri, distance) in enumerate(sorted_results):
45
  if max_distance is None or distance <= max_distance:
46
  try:
47
  img = PILImage.open(uri)
48
- with cols[i % 3]: # Use modulo to cycle through columns
49
- st.image(img, use_container_width = True)
50
  except Exception as e:
51
  st.error(f"Error loading image: {e}")
52
 
 
53
  def display_videos_streamlit(video_collection, query_text, max_distance=None, max_results=5, debug=False):
54
- """
55
- Display videos in a Streamlit app based on a query.
56
- Args:
57
- video_collection: The video collection object for querying.
58
- query_text (str): The text query for videos.
59
- max_distance (float, optional): Maximum allowable distance for filtering.
60
- max_results (int, optional): Maximum number of results to display.
61
- debug (bool, optional): Whether to print debug information.
62
- """
63
- # Deduplication set
64
  displayed_videos = set()
65
 
66
- # Query the video collection with the specified text
67
  results = video_collection.query(
68
  query_texts=[query_text],
69
- n_results=max_results, # Adjust the number of results if needed
70
  include=['uris', 'distances', 'metadatas']
71
  )
72
 
73
- # Extract URIs, distances, and metadatas from the result
74
  uris = results['uris'][0]
75
  distances = results['distances'][0]
76
  metadatas = results['metadatas'][0]
77
 
78
- # Display the videos that meet the distance criteria
79
  for uri, distance, metadata in zip(uris, distances, metadatas):
80
  video_uri = metadata['video_uri']
81
 
82
- # Check if a max_distance filter is applied and the distance is within the allowed range
83
  if (max_distance is None or distance <= max_distance) and video_uri not in displayed_videos:
84
  if debug:
85
  st.write(f"URI: {uri} - Video URI: {video_uri} - Distance: {distance}")
86
- st.video(video_uri) # Display video in Streamlit
87
- displayed_videos.add(video_uri) # Add to the set to prevent duplication
88
  else:
89
  if debug:
90
  st.write(f"URI: {uri} - Video URI: {video_uri} - Distance: {distance} (Filtered out)")
91
 
92
-
93
def image_uris(image_collection, query_text, max_distance=None, max_results=5):
    """Return the URIs of the images that match ``query_text``.

    Args:
        image_collection: Collection object exposing ``query``.
        query_text (str): Text used to search the collection.
        max_distance (float, optional): Hits with a larger distance are
            dropped. ``None`` keeps every hit.
        max_results (int, optional): Number of hits requested from the
            collection.

    Returns:
        list: URIs of the retained hits, in the order the collection
        returned them.
    """
    results = image_collection.query(
        query_texts=[query_text],
        n_results=max_results,
        include=['uris', 'distances'],
    )
    return [
        uri
        for uri, distance in zip(results['uris'][0], results['distances'][0])
        if max_distance is None or distance <= max_distance
    ]
106
-
107
def text_uris(text_collection, query_text, max_distance=None, max_results=5):
    """Return the documents that match ``query_text``.

    Despite the name, this returns the stored document texts, not URIs.

    Args:
        text_collection: Collection object exposing ``query``.
        query_text (str): Text used to search the collection.
        max_distance (float, optional): Hits with a larger distance are
            dropped. ``None`` keeps every hit.
        max_results (int, optional): Number of hits requested from the
            collection.

    Returns:
        list: Documents of the retained hits, in the order the collection
        returned them.
    """
    results = text_collection.query(
        query_texts=[query_text],
        n_results=max_results,
        include=['documents', 'distances'],
    )
    return [
        doc
        for doc, distance in zip(results['documents'][0], results['distances'][0])
        if max_distance is None or distance <= max_distance
    ]
120
-
121
def frame_uris(video_collection, query_text, max_distance=None, max_results=5):
    """Return matching frame URIs, keeping at most one frame per folder.

    Frames are grouped by the directory part of their URI
    (``os.path.dirname``); only the first retained hit of each directory is
    kept.

    Args:
        video_collection: Collection object exposing ``query``.
        query_text (str): Text used to search the collection.
        max_distance (float, optional): Hits with a larger distance are
            dropped. ``None`` keeps every hit.
        max_results (int, optional): Number of hits requested from the
            collection and upper bound on the length of the result.

    Returns:
        list: Frame URIs in the order the collection returned them, one per
        folder.
    """
    results = video_collection.query(
        query_texts=[query_text],
        n_results=max_results,
        include=['uris', 'distances'],
    )

    filtered_uris = []
    seen_folders = set()
    for uri, distance in zip(results['uris'][0], results['distances'][0]):
        if len(filtered_uris) >= max_results:
            break

        within_range = max_distance is None or distance <= max_distance
        # A hit without a URI has no folder to group by (and would make
        # os.path.dirname raise), so it is skipped like an out-of-range hit.
        if not uri or not within_range:
            continue

        folder = os.path.dirname(uri)
        if folder not in seen_folders:
            seen_folders.add(folder)
            filtered_uris.append(uri)

    return filtered_uris
142
-
143
def image_uris2(image_collection2, query_text, max_distance=None, max_results=5):
    """Return the URIs of the images in ``image_collection2`` matching a query.

    Args:
        image_collection2: Collection object exposing ``query``.
        query_text (str): Text used to search the collection.
        max_distance (float, optional): Hits with a larger distance are
            dropped. ``None`` keeps every hit.
        max_results (int, optional): Number of hits requested from the
            collection.

    Returns:
        list: URIs of the retained hits, in the order the collection
        returned them.
    """
    results = image_collection2.query(
        query_texts=[query_text],
        n_results=max_results,
        include=['uris', 'distances'],
    )
    return [
        uri
        for uri, distance in zip(results['uris'][0], results['distances'][0])
        if max_distance is None or distance <= max_distance
    ]
156
-
157
  def format_prompt_inputs(image_collection, video_collection, user_query):
158
- # Get frame candidates from the video collection
159
  frame_candidates = frame_uris(video_collection, user_query, max_distance=1.55)
160
-
161
- # Get image candidates from the image collection
162
  image_candidates = image_uris(image_collection, user_query, max_distance=1.5)
163
 
164
- # Initialize the inputs dictionary with just the query
165
  inputs = {"query": user_query}
166
 
167
- # Add the frame if found
168
  frame = frame_candidates[0] if frame_candidates else ""
169
  inputs["frame"] = frame
170
 
171
- # If image candidates exist, process the first image
172
  if image_candidates:
173
  image = image_candidates[0]
174
  with PILImage.open(image) as img:
175
- img = img.resize((img.width // 6, img.height // 6)) # Resize the image
176
- img = img.convert("L") # Convert to grayscale
177
  with io.BytesIO() as output:
178
- img.save(output, format="JPEG", quality=60) # Save as JPEG with compression
179
  compressed_image_data = output.getvalue()
180
 
181
- # Encode the compressed image as base64
182
  inputs["image_data_1"] = base64.b64encode(compressed_image_data).decode('utf-8')
183
  else:
184
  inputs["image_data_1"] = ""
185
 
186
  return inputs
187
 
188
-
189
  def home():
190
- st.header("Welcome")
191
- #st.set_page_config(layout='wide', page_title="Virtual Tutor")
 
 
 
 
 
192
  st.markdown("""
193
  <svg width="600" height="100">
194
  <text x="50%" y="50%" font-family="San serif" font-size="42px" fill="Black" text-anchor="middle" stroke="white"
@@ -197,11 +115,11 @@ def home():
197
  </svg>
198
  """, unsafe_allow_html=True)
199
 
 
200
  if "messages" not in st.session_state:
201
- st.session_state.messages = [
202
- {"role": "assistant", "content": "Hi! How may I assist you today?"}
203
- ]
204
 
 
205
  st.markdown("""
206
  <style>
207
  .stChatInputContainer > div {
@@ -210,38 +128,51 @@ def home():
210
  </style>
211
  """, unsafe_allow_html=True)
212
 
213
- for message in st.session_state.messages: # Display the prior chat messages
 
214
  with st.chat_message(message["role"]):
215
  st.write(message["content"])
216
 
 
217
  for i, msg in enumerate(memory_storage.messages):
218
  name = "user" if i % 2 == 0 else "assistant"
219
  st.chat_message(name).markdown(msg.content)
220
 
221
- if user_input := st.chat_input("User Input"):
 
222
  with st.chat_message("user"):
223
  st.markdown(user_input)
224
 
225
  with st.spinner("Generating Response..."):
226
  with st.chat_message("assistant"):
227
  response = get_answer(user_input)
228
- answer = response['result']
229
  st.markdown(answer)
230
-
 
231
  message = {"role": "assistant", "content": answer}
232
  message_u = {"role": "user", "content": user_input}
233
  st.session_state.messages.append(message_u)
234
  st.session_state.messages.append(message)
235
- inputs = format_prompt_inputs(image_collection,video_collection, user_input)
 
 
 
 
236
  st.markdown("### Images")
237
- display_images(image_collection, query, max_distance=1.55, debug=False)
 
 
238
  st.markdown("### Videos")
239
  frame = inputs["frame"]
240
  if frame:
241
- directory_name = frame.split('/')[1]
242
  video_path = f"videos_flattened/{directory_name}.mp4"
243
  if os.path.exists(video_path):
244
  st.video(video_path)
245
  else:
246
- st.write("No related videos found.")
247
-
 
 
 
 
1
+ import chromadb
2
+ from PIL import Image as PILImage
3
  import streamlit as st
4
+ import os
5
+ from utils.qa import chain
6
  from langchain.memory import ConversationBufferWindowMemory
7
  from langchain_community.chat_message_histories import StreamlitChatMessageHistory
8
+ import base64
9
+ import io
10
 
11
+ # Initialize Chromadb client
12
  path = "mm_vdb2"
13
  client = chromadb.PersistentClient(path=path)
14
  image_collection = client.get_collection(name="image")
15
  video_collection = client.get_collection(name='video_collection')
16
 
17
+ # Set up memory storage for the chat
18
  memory_storage = StreamlitChatMessageHistory(key="chat_messages")
19
  memory = ConversationBufferWindowMemory(memory_key="chat_history", human_prefix="User", chat_memory=memory_storage, k=3)
20
 
21
+ # Function to get an answer from the chain
22
  def get_answer(query):
23
  response = chain.invoke(query)
24
+ return response.get("result", "No result found.")
 
25
 
26
+ # Function to display images in the UI
27
  def display_images(image_collection, query_text, max_distance=None, debug=False):
 
 
 
 
 
 
 
 
28
  results = image_collection.query(
29
  query_texts=[query_text],
30
  n_results=10,
 
34
  uris = results['uris'][0]
35
  distances = results['distances'][0]
36
 
 
37
  sorted_results = sorted(zip(uris, distances), key=lambda x: x[0])
38
 
39
+ cols = st.columns(3)
 
40
 
41
  for i, (uri, distance) in enumerate(sorted_results):
42
  if max_distance is None or distance <= max_distance:
43
  try:
44
  img = PILImage.open(uri)
45
+ with cols[i % 3]:
46
+ st.image(img, use_container_width=True)
47
  except Exception as e:
48
  st.error(f"Error loading image: {e}")
49
 
50
+ # Function to display videos in the UI
51
  def display_videos_streamlit(video_collection, query_text, max_distance=None, max_results=5, debug=False):
 
 
 
 
 
 
 
 
 
 
52
  displayed_videos = set()
53
 
 
54
  results = video_collection.query(
55
  query_texts=[query_text],
56
+ n_results=max_results,
57
  include=['uris', 'distances', 'metadatas']
58
  )
59
 
 
60
  uris = results['uris'][0]
61
  distances = results['distances'][0]
62
  metadatas = results['metadatas'][0]
63
 
 
64
  for uri, distance, metadata in zip(uris, distances, metadatas):
65
  video_uri = metadata['video_uri']
66
 
 
67
  if (max_distance is None or distance <= max_distance) and video_uri not in displayed_videos:
68
  if debug:
69
  st.write(f"URI: {uri} - Video URI: {video_uri} - Distance: {distance}")
70
+ st.video(video_uri)
71
+ displayed_videos.add(video_uri)
72
  else:
73
  if debug:
74
  st.write(f"URI: {uri} - Video URI: {video_uri} - Distance: {distance} (Filtered out)")
75
 
76
+ # Function to format the inputs for image and video processing
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
77
  def format_prompt_inputs(image_collection, video_collection, user_query):
 
78
  frame_candidates = frame_uris(video_collection, user_query, max_distance=1.55)
 
 
79
  image_candidates = image_uris(image_collection, user_query, max_distance=1.5)
80
 
 
81
  inputs = {"query": user_query}
82
 
 
83
  frame = frame_candidates[0] if frame_candidates else ""
84
  inputs["frame"] = frame
85
 
 
86
  if image_candidates:
87
  image = image_candidates[0]
88
  with PILImage.open(image) as img:
89
+ img = img.resize((img.width // 6, img.height // 6))
90
+ img = img.convert("L")
91
  with io.BytesIO() as output:
92
+ img.save(output, format="JPEG", quality=60)
93
  compressed_image_data = output.getvalue()
94
 
 
95
  inputs["image_data_1"] = base64.b64encode(compressed_image_data).decode('utf-8')
96
  else:
97
  inputs["image_data_1"] = ""
98
 
99
  return inputs
100
 
101
+ # Main function to initialize and run the UI
102
  def home():
103
+ # Set up the page layout
104
+ st.set_page_config(layout='wide', page_title="Virtual Tutor")
105
+
106
+ # Header
107
+ st.header("Welcome to Virtual Tutor - CHAT")
108
+
109
+ # SVG Banner for UI branding
110
  st.markdown("""
111
  <svg width="600" height="100">
112
  <text x="50%" y="50%" font-family="San serif" font-size="42px" fill="Black" text-anchor="middle" stroke="white"
 
115
  </svg>
116
  """, unsafe_allow_html=True)
117
 
118
+ # Initialize the chat session if not already initialized
119
  if "messages" not in st.session_state:
120
+ st.session_state.messages = [{"role": "assistant", "content": "Hi! How may I assist you today?"}]
 
 
121
 
122
+ # Styling for the chat input container
123
  st.markdown("""
124
  <style>
125
  .stChatInputContainer > div {
 
128
  </style>
129
  """, unsafe_allow_html=True)
130
 
131
+ # Display previous chat messages
132
+ for message in st.session_state.messages:
133
  with st.chat_message(message["role"]):
134
  st.write(message["content"])
135
 
136
+ # Display chat messages from memory
137
  for i, msg in enumerate(memory_storage.messages):
138
  name = "user" if i % 2 == 0 else "assistant"
139
  st.chat_message(name).markdown(msg.content)
140
 
141
+ # Handle user input and generate response
142
+ if user_input := st.chat_input("Enter your question here..."):
143
  with st.chat_message("user"):
144
  st.markdown(user_input)
145
 
146
  with st.spinner("Generating Response..."):
147
  with st.chat_message("assistant"):
148
  response = get_answer(user_input)
149
+ answer = response
150
  st.markdown(answer)
151
+
152
+ # Save user and assistant messages to session state
153
  message = {"role": "assistant", "content": answer}
154
  message_u = {"role": "user", "content": user_input}
155
  st.session_state.messages.append(message_u)
156
  st.session_state.messages.append(message)
157
+
158
+ # Process inputs for image/video
159
+ inputs = format_prompt_inputs(image_collection, video_collection, user_input)
160
+
161
+ # Display images
162
  st.markdown("### Images")
163
+ display_images(image_collection, user_input, max_distance=1.55, debug=False)
164
+
165
+ # Display videos based on frames
166
  st.markdown("### Videos")
167
  frame = inputs["frame"]
168
  if frame:
169
+ directory_name = frame.split('/')[1]
170
  video_path = f"videos_flattened/{directory_name}.mp4"
171
  if os.path.exists(video_path):
172
  st.video(video_path)
173
  else:
174
+ st.error("Video file not found.")
175
+
176
+ # Call the home function to run the app
177
+ if __name__ == "__main__":
178
+ home()