NEXAS committed on
Commit
1131ca9
·
verified ·
1 Parent(s): 0ce361f

Update user.py

Browse files
Files changed (1) hide show
  1. user.py +156 -50
user.py CHANGED
@@ -17,55 +17,6 @@ def get_answer(query):
17
  #return response["result"]
18
  return response
19
 
20
- def home():
21
- st.header("Welcome")
22
- #st.set_page_config(layout='wide', page_title="Virtual Tutor")
23
- st.markdown("""
24
- <svg width="600" height="100">
25
- <text x="50%" y="50%" font-family="San serif" font-size="42px" fill="Black" text-anchor="middle" stroke="white"
26
- stroke-width="0.3" stroke-linejoin="round">Virtual Tutor - CHAT
27
- </text>
28
- </svg>
29
- """, unsafe_allow_html=True)
30
-
31
- if "messages" not in st.session_state:
32
- st.session_state.messages = [
33
- {"role": "assistant", "content": "Hi! How may I assist you today?"}
34
- ]
35
-
36
- st.markdown("""
37
- <style>
38
- .stChatInputContainer > div {
39
- background-color: #000000;
40
- }
41
- </style>
42
- """, unsafe_allow_html=True)
43
-
44
- for message in st.session_state.messages: # Display the prior chat messages
45
- with st.chat_message(message["role"]):
46
- st.write(message["content"])
47
-
48
- for i, msg in enumerate(memory_storage.messages):
49
- name = "user" if i % 2 == 0 else "assistant"
50
- st.chat_message(name).markdown(msg.content)
51
-
52
- if user_input := st.chat_input("User Input"):
53
- with st.chat_message("user"):
54
- st.markdown(user_input)
55
-
56
- with st.spinner("Generating Response..."):
57
- with st.chat_message("assistant"):
58
- response = get_answer(user_input)
59
- answer = response['result']
60
- st.markdown(answer)
61
- message = {"role": "assistant", "content": answer}
62
- message_u = {"role": "user", "content": user_input}
63
- st.session_state.messages.append(message_u)
64
- st.session_state.messages.append(message)
65
- display_images(user_input)
66
- display_videos_streamlit(user_input)
67
-
68
-
69
  def display_images(image_collection, query_text, max_distance=None, debug=False):
70
  """
71
  Display images in a Streamlit app based on a query.
@@ -138,4 +89,159 @@ def display_videos_streamlit(video_collection, query_text, max_distance=None, ma
138
  if debug:
139
  st.write(f"URI: {uri} - Video URI: {video_uri} - Distance: {distance} (Filtered out)")
140
 
141
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
  #return response["result"]
18
  return response
19
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
20
  def display_images(image_collection, query_text, max_distance=None, debug=False):
21
  """
22
  Display images in a Streamlit app based on a query.
 
89
  if debug:
90
  st.write(f"URI: {uri} - Video URI: {video_uri} - Distance: {distance} (Filtered out)")
91
 
92
+
93
def image_uris(image_collection, query_text, max_distance=None, max_results=5):
    """Return URIs of images matching *query_text*, filtered by distance.

    Args:
        image_collection: Collection object exposing a Chroma-style ``query``
            method (presumably a ChromaDB collection — confirm against caller).
        query_text: Free-text query to search with.
        max_distance: Optional upper bound on match distance; ``None`` keeps
            every result.
        max_results: Maximum number of candidates requested from the collection.

    Returns:
        List of image URIs whose distance passes the cutoff, best-first.
    """
    hits = image_collection.query(
        query_texts=[query_text],
        n_results=max_results,
        include=['uris', 'distances'],
    )
    # Only the first (and only) query's results are relevant.
    pairs = zip(hits['uris'][0], hits['distances'][0])
    return [uri for uri, dist in pairs
            if max_distance is None or dist <= max_distance]
106
+
107
def text_uris(text_collection, query_text, max_distance=None, max_results=5):
    """Return documents matching *query_text*, filtered by distance.

    Args:
        text_collection: Collection object exposing a Chroma-style ``query``
            method (presumably a ChromaDB collection — confirm against caller).
        query_text: Free-text query to search with.
        max_distance: Optional upper bound on match distance; ``None`` keeps
            every result.
        max_results: Maximum number of candidates requested from the collection.

    Returns:
        List of document texts whose distance passes the cutoff, best-first.
    """
    hits = text_collection.query(
        query_texts=[query_text],
        n_results=max_results,
        include=['documents', 'distances'],
    )
    # Only the first (and only) query's results are relevant.
    pairs = zip(hits['documents'][0], hits['distances'][0])
    return [doc for doc, dist in pairs
            if max_distance is None or dist <= max_distance]
120
+
121
def frame_uris(video_collection, query_text, max_distance=None, max_results=5):
    """Return frame URIs matching *query_text*, at most one per folder.

    Frames extracted from the same video share a parent directory, so
    deduplicating on ``os.path.dirname`` yields one representative frame
    per source video.

    Args:
        video_collection: Collection object exposing a Chroma-style ``query``
            method (presumably a ChromaDB collection — confirm against caller).
        query_text: Free-text query to search with.
        max_distance: Optional upper bound on match distance; ``None`` keeps
            every result.
        max_results: Maximum candidates requested AND maximum URIs returned.

    Returns:
        List of frame URIs, best-first, one per distinct parent directory.
    """
    hits = video_collection.query(
        query_texts=[query_text],
        n_results=max_results,
        include=['uris', 'distances'],
    )

    selected = []
    seen_folders = set()
    for uri, dist in zip(hits['uris'][0], hits['distances'][0]):
        if max_distance is not None and dist > max_distance:
            continue
        folder = os.path.dirname(uri)
        if folder in seen_folders:
            continue
        seen_folders.add(folder)
        selected.append(uri)
        if len(selected) == max_results:
            break
    return selected
142
+
143
def image_uris2(image_collection2, query_text, max_distance=None, max_results=5):
    """Return URIs from the secondary image collection, filtered by distance.

    Same contract as ``image_uris`` but queried against a second collection.

    Args:
        image_collection2: Collection object exposing a Chroma-style ``query``
            method (presumably a ChromaDB collection — confirm against caller).
        query_text: Free-text query to search with.
        max_distance: Optional upper bound on match distance; ``None`` keeps
            every result.
        max_results: Maximum number of candidates requested from the collection.

    Returns:
        List of image URIs whose distance passes the cutoff, best-first.
    """
    hits = image_collection2.query(
        query_texts=[query_text],
        n_results=max_results,
        include=['uris', 'distances'],
    )
    kept = []
    for uri, dist in zip(hits['uris'][0], hits['distances'][0]):
        if max_distance is None or dist <= max_distance:
            kept.append(uri)
    return kept
156
+
157
def format_prompt_inputs(image_collection, video_collection, user_query):
    """Build the prompt-input dict for *user_query*.

    Returns a dict with three keys:
        ``query``        — the user query, unchanged.
        ``frame``        — URI of the best-matching video frame, or "".
        ``image_data_1`` — base64-encoded JPEG of the best-matching image,
                           downscaled and grayscaled to shrink the payload,
                           or "" when no image qualifies.
    """
    inputs = {"query": user_query}

    # Best video-frame match, if any (distance cutoffs presumably tuned
    # empirically — confirm against retrieval quality).
    frame_candidates = frame_uris(video_collection, user_query, max_distance=1.55)
    inputs["frame"] = frame_candidates[0] if frame_candidates else ""

    # Best still-image match, compressed before base64 encoding.
    image_candidates = image_uris(image_collection, user_query, max_distance=1.5)
    if not image_candidates:
        inputs["image_data_1"] = ""
        return inputs

    with PILImage.open(image_candidates[0]) as img:
        shrunk = img.resize((img.width // 6, img.height // 6))  # 6x downscale
        gray = shrunk.convert("L")  # grayscale cuts payload size further
        with io.BytesIO() as buffer:
            gray.save(buffer, format="JPEG", quality=60)  # lossy re-encode
            compressed_image_data = buffer.getvalue()

    inputs["image_data_1"] = base64.b64encode(compressed_image_data).decode('utf-8')
    return inputs
187
+
188
+
189
def home():
    """Render the Virtual Tutor chat page.

    Draws the banner, replays chat history from ``st.session_state`` and
    from ``memory_storage``, then handles a new user turn: generates an
    answer via ``get_answer``, shows related images, and plays a related
    video resolved from the best-matching frame URI.

    Relies on module-level globals: ``st``, ``memory_storage``,
    ``image_collection``, ``video_collection``.
    """
    st.header("Welcome")
    # st.set_page_config must be the first Streamlit call in the script,
    # so it is intentionally left to the entry point.
    st.markdown("""
    <svg width="600" height="100">
        <text x="50%" y="50%" font-family="San serif" font-size="42px" fill="Black" text-anchor="middle" stroke="white"
        stroke-width="0.3" stroke-linejoin="round">Virtual Tutor - CHAT
        </text>
    </svg>
    """, unsafe_allow_html=True)

    # Seed the conversation with a greeting on first load.
    if "messages" not in st.session_state:
        st.session_state.messages = [
            {"role": "assistant", "content": "Hi! How may I assist you today?"}
        ]

    st.markdown("""
    <style>
    .stChatInputContainer > div {
        background-color: #000000;
    }
    </style>
    """, unsafe_allow_html=True)

    # Replay the prior chat messages kept in session state.
    for message in st.session_state.messages:
        with st.chat_message(message["role"]):
            st.write(message["content"])

    # Replay memory-backed turns; even indices are assumed to be user
    # turns and odd ones assistant turns — TODO confirm memory layout.
    for i, msg in enumerate(memory_storage.messages):
        name = "user" if i % 2 == 0 else "assistant"
        st.chat_message(name).markdown(msg.content)

    if user_input := st.chat_input("User Input"):
        with st.chat_message("user"):
            st.markdown(user_input)

        with st.spinner("Generating Response..."):
            with st.chat_message("assistant"):
                response = get_answer(user_input)
                answer = response['result']
                st.markdown(answer)

                # Persist the turn (user first, then assistant).
                st.session_state.messages.append(
                    {"role": "user", "content": user_input})
                st.session_state.messages.append(
                    {"role": "assistant", "content": answer})

                inputs = format_prompt_inputs(image_collection,
                                              video_collection, user_input)

                st.markdown("### Images")
                # BUG FIX: was `query`, an undefined name (NameError at
                # runtime); the user's input is the query text here.
                display_images(image_collection, user_input,
                               max_distance=1.55, debug=False)

                st.markdown("### Videos")
                frame = inputs["frame"]
                if frame:
                    # Frame URIs are assumed shaped like
                    # "<root>/<video_name>/<frame>.jpg", so segment [1] is
                    # the video folder name — TODO confirm against indexer.
                    directory_name = frame.split('/')[1]
                    video_path = f"videos_flattened/{directory_name}.mp4"
                    if os.path.exists(video_path):
                        st.video(video_path)
                else:
                    st.write("No related videos found.")