awacke1 commited on
Commit
642b060
Β·
verified Β·
1 Parent(s): 7932bd1

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +414 -0
app.py ADDED
@@ -0,0 +1,414 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import pandas as pd
3
+ import numpy as np
4
+ from sentence_transformers import SentenceTransformer
5
+ from sklearn.metrics.pairwise import cosine_similarity
6
+ import torch
7
+ import json
8
+ import os
9
+ import glob
10
+ from pathlib import Path
11
+ from datetime import datetime
12
+ import edge_tts
13
+ import asyncio
14
+ import base64
15
+ import requests
16
+ from collections import defaultdict
17
+ from audio_recorder_streamlit import audio_recorder
18
+ import streamlit.components.v1 as components
19
+ import re
20
+ from urllib.parse import quote
21
+ from xml.etree import ElementTree as ET
22
+
23
+ # Initialize session state
24
+ if 'search_history' not in st.session_state:
25
+ st.session_state['search_history'] = []
26
+ if 'last_voice_input' not in st.session_state:
27
+ st.session_state['last_voice_input'] = ""
28
+ if 'transcript_history' not in st.session_state:
29
+ st.session_state['transcript_history'] = []
30
+ if 'should_rerun' not in st.session_state:
31
+ st.session_state['should_rerun'] = False
32
+ if 'search_columns' not in st.session_state:
33
+ st.session_state['search_columns'] = []
34
+ if 'initial_search_done' not in st.session_state:
35
+ st.session_state['initial_search_done'] = False
36
+ if 'tts_voice' not in st.session_state:
37
+ st.session_state['tts_voice'] = "en-US-AriaNeural"
38
+ if 'arxiv_last_query' not in st.session_state:
39
+ st.session_state['arxiv_last_query'] = ""
40
+ if 'old_val' not in st.session_state:
41
+ st.session_state['old_val'] = None
42
+
43
+ def highlight_text(text, query):
44
+ """Highlight case-insensitive occurrences of query in text with bold formatting."""
45
+ if not query:
46
+ return text
47
+ pattern = re.compile(re.escape(query), re.IGNORECASE)
48
+ return pattern.sub(lambda m: f"**{m.group(0)}**", text)
49
+
50
+ class VideoSearch:
51
+ def __init__(self):
52
+ self.text_model = SentenceTransformer('all-MiniLM-L6-v2')
53
+ self.load_dataset()
54
+
55
+ def fetch_dataset_rows(self):
56
+ """Fetch dataset from Hugging Face API"""
57
+ try:
58
+ url = "https://datasets-server.huggingface.co/first-rows?dataset=omegalabsinc%2Fomega-multimodal&config=default&split=train"
59
+ response = requests.get(url, timeout=30)
60
+ if response.status_code == 200:
61
+ data = response.json()
62
+ if 'rows' in data:
63
+ processed_rows = []
64
+ for row_data in data['rows']:
65
+ row = row_data.get('row', row_data)
66
+ for key in row:
67
+ if any(term in key.lower() for term in ['embed', 'vector', 'encoding']):
68
+ if isinstance(row[key], str):
69
+ try:
70
+ row[key] = [float(x.strip()) for x in row[key].strip('[]').split(',') if x.strip()]
71
+ except:
72
+ continue
73
+ processed_rows.append(row)
74
+
75
+ df = pd.DataFrame(processed_rows)
76
+ st.session_state['search_columns'] = [col for col in df.columns
77
+ if col not in ['video_embed', 'description_embed', 'audio_embed']]
78
+ return df
79
+ return self.load_example_data()
80
+ except:
81
+ return self.load_example_data()
82
+
83
+ def prepare_features(self):
84
+ """Prepare embeddings with adaptive field detection"""
85
+ try:
86
+ embed_cols = [col for col in self.dataset.columns
87
+ if any(term in col.lower() for term in ['embed', 'vector', 'encoding'])]
88
+
89
+ embeddings = {}
90
+ for col in embed_cols:
91
+ try:
92
+ data = []
93
+ for row in self.dataset[col]:
94
+ if isinstance(row, str):
95
+ values = [float(x.strip()) for x in row.strip('[]').split(',') if x.strip()]
96
+ elif isinstance(row, list):
97
+ values = row
98
+ else:
99
+ continue
100
+ data.append(values)
101
+
102
+ if data:
103
+ embeddings[col] = np.array(data)
104
+ except:
105
+ continue
106
+
107
+ if 'video_embed' in embeddings:
108
+ self.video_embeds = embeddings['video_embed']
109
+ else:
110
+ self.video_embeds = next(iter(embeddings.values()))
111
+
112
+ if 'description_embed' in embeddings:
113
+ self.text_embeds = embeddings['description_embed']
114
+ else:
115
+ self.text_embeds = self.video_embeds
116
+
117
+ except:
118
+ # Fallback to random embeddings
119
+ num_rows = len(self.dataset)
120
+ self.video_embeds = np.random.randn(num_rows, 384)
121
+ self.text_embeds = np.random.randn(num_rows, 384)
122
+
123
+ def load_example_data(self):
124
+ """Load example data as fallback"""
125
+ example_data = [
126
+ {
127
+ "video_id": "cd21da96-fcca-4c94-a60f-0b1e4e1e29fc",
128
+ "youtube_id": "IO-vwtyicn4",
129
+ "description": "This video shows a close-up of an ancient text carved into a surface.",
130
+ "views": 45489,
131
+ "start_time": 1452,
132
+ "end_time": 1458,
133
+ "video_embed": [0.014160037972033024, -0.003111184574663639, -0.016604168340563774],
134
+ "description_embed": [-0.05835828185081482, 0.02589797042310238, 0.11952091753482819]
135
+ }
136
+ ]
137
+ return pd.DataFrame(example_data)
138
+
139
+ def load_dataset(self):
140
+ self.dataset = self.fetch_dataset_rows()
141
+ self.prepare_features()
142
+
143
+ def search(self, query, column=None, top_k=20):
144
+ # Semantic search
145
+ query_embedding = self.text_model.encode([query])[0]
146
+ video_sims = cosine_similarity([query_embedding], self.video_embeds)[0]
147
+ text_sims = cosine_similarity([query_embedding], self.text_embeds)[0]
148
+ combined_sims = 0.5 * video_sims + 0.5 * text_sims
149
+
150
+ # If a column is selected (not All Fields), strictly filter by textual match
151
+ if column and column in self.dataset.columns and column != "All Fields":
152
+ mask = self.dataset[column].astype(str).str.contains(query, case=False, na=False)
153
+ # Only keep rows that contain the query text in the selected column
154
+ combined_sims = combined_sims[mask]
155
+ filtered_dataset = self.dataset[mask].copy()
156
+ else:
157
+ filtered_dataset = self.dataset.copy()
158
+
159
+ # Get top results
160
+ top_k = min(top_k, len(combined_sims))
161
+ if top_k == 0:
162
+ return []
163
+ top_indices = np.argsort(combined_sims)[-top_k:][::-1]
164
+
165
+ results = []
166
+ filtered_dataset = filtered_dataset.iloc[top_indices]
167
+ filtered_sims = combined_sims[top_indices]
168
+ for idx, row in zip(top_indices, filtered_dataset.itertuples()):
169
+ result = {'relevance_score': float(filtered_sims[list(top_indices).index(idx)])}
170
+ for col in filtered_dataset.columns:
171
+ if col not in ['video_embed', 'description_embed', 'audio_embed']:
172
+ result[col] = getattr(row, col)
173
+ results.append(result)
174
+
175
+ return results
176
+
177
+ @st.cache_resource
178
+ def get_speech_model():
179
+ return edge_tts.Communicate
180
+
181
+ async def generate_speech(text, voice=None):
182
+ if not text.strip():
183
+ return None
184
+ if not voice:
185
+ voice = st.session_state['tts_voice']
186
+ try:
187
+ communicate = get_speech_model()(text, voice)
188
+ audio_file = f"speech_{datetime.now().strftime('%Y%m%d_%H%M%S')}.mp3"
189
+ await communicate.save(audio_file)
190
+ return audio_file
191
+ except Exception as e:
192
+ st.error(f"Error generating speech: {e}")
193
+ return None
194
+
195
+ def show_file_manager():
196
+ """Display file manager interface"""
197
+ st.subheader("πŸ“‚ File Manager")
198
+ col1, col2 = st.columns(2)
199
+ with col1:
200
+ uploaded_file = st.file_uploader("Upload File", type=['txt', 'md', 'mp3'])
201
+ if uploaded_file:
202
+ with open(uploaded_file.name, "wb") as f:
203
+ f.write(uploaded_file.getvalue())
204
+ st.success(f"Uploaded: {uploaded_file.name}")
205
+ st.experimental_rerun()
206
+
207
+ with col2:
208
+ if st.button("πŸ—‘ Clear All Files"):
209
+ for f in glob.glob("*.txt") + glob.glob("*.md") + glob.glob("*.mp3"):
210
+ os.remove(f)
211
+ st.success("All files cleared!")
212
+ st.experimental_rerun()
213
+
214
+ files = glob.glob("*.txt") + glob.glob("*.md") + glob.glob("*.mp3")
215
+ if files:
216
+ st.write("### Existing Files")
217
+ for f in files:
218
+ with st.expander(f"πŸ“„ {os.path.basename(f)}"):
219
+ if f.endswith('.mp3'):
220
+ st.audio(f)
221
+ else:
222
+ with open(f, 'r', encoding='utf-8') as file:
223
+ st.text_area("Content", file.read(), height=100)
224
+ if st.button(f"Delete {os.path.basename(f)}", key=f"del_{f}"):
225
+ os.remove(f)
226
+ st.experimental_rerun()
227
+
228
+ def arxiv_search(query, max_results=5):
229
+ """Perform a simple Arxiv search using their API and return top results."""
230
+ base_url = "http://export.arxiv.org/api/query?"
231
+ search_url = base_url + f"search_query={quote(query)}&start=0&max_results={max_results}"
232
+ r = requests.get(search_url)
233
+ if r.status_code == 200:
234
+ root = ET.fromstring(r.text)
235
+ ns = {'atom': 'http://www.w3.org/2005/Atom'}
236
+ entries = root.findall('atom:entry', ns)
237
+ results = []
238
+ for entry in entries:
239
+ title = entry.find('atom:title', ns).text.strip()
240
+ summary = entry.find('atom:summary', ns).text.strip()
241
+ link = None
242
+ for l in entry.findall('atom:link', ns):
243
+ if l.get('type') == 'text/html':
244
+ link = l.get('href')
245
+ break
246
+ results.append((title, summary, link))
247
+ return results
248
+ return []
249
+
250
+ def perform_arxiv_lookup(q, vocal_summary=True, titles_summary=True, full_audio=False):
251
+ results = arxiv_search(q, max_results=5)
252
+ if not results:
253
+ st.write("No Arxiv results found.")
254
+ return
255
+ st.markdown(f"**Arxiv Search Results for '{q}':**")
256
+ for i, (title, summary, link) in enumerate(results, start=1):
257
+ st.markdown(f"**{i}. {title}**")
258
+ st.write(summary)
259
+ if link:
260
+ st.markdown(f"[View Paper]({link})")
261
+
262
+ # TTS Options
263
+ if vocal_summary:
264
+ spoken_text = f"Here are some Arxiv results for {q}. "
265
+ if titles_summary:
266
+ spoken_text += " Titles: " + ", ".join([res[0] for res in results])
267
+ else:
268
+ spoken_text += " " + results[0][1][:200]
269
+
270
+ audio_file = asyncio.run(generate_speech(spoken_text))
271
+ if audio_file:
272
+ st.audio(audio_file)
273
+
274
+ if full_audio:
275
+ full_text = ""
276
+ for i,(title, summary, _) in enumerate(results, start=1):
277
+ full_text += f"Result {i}: {title}. {summary} "
278
+ audio_file_full = asyncio.run(generate_speech(full_text))
279
+ if audio_file_full:
280
+ st.write("### Full Audio")
281
+ st.audio(audio_file_full)
282
+
283
+ def main():
284
+ st.title("πŸŽ₯ Video & Arxiv Search with Voice Input")
285
+
286
+ search = VideoSearch()
287
+
288
+ tab1, tab2, tab3, tab4 = st.tabs(["πŸ” Search", "πŸŽ™οΈ Voice Input", "πŸ“š Arxiv", "πŸ“‚ Files"])
289
+
290
+ # ---- Tab 1: Video Search ----
291
+ with tab1:
292
+ st.subheader("Search Videos")
293
+ col1, col2 = st.columns([3, 1])
294
+ with col1:
295
+ query = st.text_input("Enter your search query:",
296
+ value="ancient" if not st.session_state['initial_search_done'] else "")
297
+ with col2:
298
+ search_column = st.selectbox("Search in field:",
299
+ ["All Fields"] + st.session_state['search_columns'])
300
+
301
+ col3, col4 = st.columns(2)
302
+ with col3:
303
+ num_results = st.slider("Number of results:", 1, 100, 20)
304
+ with col4:
305
+ search_button = st.button("πŸ” Search")
306
+
307
+ if (search_button or not st.session_state['initial_search_done']) and query:
308
+ st.session_state['initial_search_done'] = True
309
+ selected_column = None if search_column == "All Fields" else search_column
310
+ with st.spinner("Searching..."):
311
+ results = search.search(query, selected_column, num_results)
312
+
313
+ st.session_state['search_history'].append({
314
+ 'query': query,
315
+ 'timestamp': datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
316
+ 'results': results[:5]
317
+ })
318
+
319
+ for i, result in enumerate(results, 1):
320
+ # Highlight the query in the description
321
+ highlighted_desc = highlight_text(result['description'], query)
322
+ with st.expander(f"Result {i}: {result['description'][:100]}...", expanded=(i==1)):
323
+ cols = st.columns([2, 1])
324
+ with cols[0]:
325
+ st.markdown("**Description:**")
326
+ st.write(highlighted_desc)
327
+ st.markdown(f"**Time Range:** {result['start_time']}s - {result['end_time']}s")
328
+ st.markdown(f"**Views:** {result['views']:,}")
329
+
330
+ with cols[1]:
331
+ st.markdown(f"**Relevance Score:** {result['relevance_score']:.2%}")
332
+ if result.get('youtube_id'):
333
+ st.video(f"https://youtube.com/watch?v={result['youtube_id']}&t={result['start_time']}")
334
+
335
+ if st.button(f"πŸ”Š Audio Summary {i}", key=f"audio_{i}"):
336
+ summary = f"Video summary: {result['description'][:200]}"
337
+ audio_file = asyncio.run(generate_speech(summary))
338
+ if audio_file:
339
+ st.audio(audio_file)
340
+
341
+ # ---- Tab 2: Voice Input ----
342
+ # Reintroduce the mycomponent from earlier code for voice input accumulation
343
+ with tab2:
344
+ st.subheader("Voice Input (HTML Component)")
345
+
346
+ # Declare the custom component
347
+ mycomponent = components.declare_component("mycomponent", path="mycomponent")
348
+
349
+ # Use the component to get accumulated voice input
350
+ val = mycomponent(my_input_value="Hello")
351
+
352
+ if val:
353
+ val_stripped = val.replace('\n', ' ')
354
+ edited_input = st.text_area("✏️ Edit Input:", value=val_stripped, height=100)
355
+
356
+ # Just allow searching the videos from the edited input
357
+ if st.button("πŸ” Search from Edited Voice Input"):
358
+ results = search.search(edited_input, None, 20)
359
+ for i, result in enumerate(results, 1):
360
+ # Highlight query in description
361
+ highlighted_desc = highlight_text(result['description'], edited_input)
362
+ with st.expander(f"Result {i}", expanded=(i==1)):
363
+ st.write(highlighted_desc)
364
+ if result.get('youtube_id'):
365
+ st.video(f"https://youtube.com/watch?v={result['youtube_id']}&t={result.get('start_time', 0)}")
366
+
367
+ # Optionally also let user record audio via audio_recorder (not integrated with transcription)
368
+ st.write("Or record audio (not ASR integrated):")
369
+ audio_bytes = audio_recorder()
370
+ if audio_bytes:
371
+ audio_path = f"temp_audio_{datetime.now().strftime('%Y%m%d_%H%M%S')}.wav"
372
+ with open(audio_path, "wb") as f:
373
+ f.write(audio_bytes)
374
+ st.success("Audio recorded successfully!")
375
+ # No transcription is provided since no external ASR is included here.
376
+ if os.path.exists(audio_path):
377
+ os.remove(audio_path)
378
+
379
+ # ---- Tab 3: Arxiv Search ----
380
+ with tab3:
381
+ st.subheader("Arxiv Search")
382
+ q = st.text_input("Enter your Arxiv search query:", value=st.session_state['arxiv_last_query'])
383
+ vocal_summary = st.checkbox("πŸŽ™ Short Audio Summary", value=True)
384
+ titles_summary = st.checkbox("πŸ”– Titles Only", value=True)
385
+ full_audio = st.checkbox("πŸ“š Full Audio Results", value=False)
386
+
387
+ if st.button("πŸ” Arxiv Search"):
388
+ st.session_state['arxiv_last_query'] = q
389
+ perform_arxiv_lookup(q, vocal_summary=vocal_summary, titles_summary=titles_summary, full_audio=full_audio)
390
+
391
+ # ---- Tab 4: File Manager ----
392
+ with tab4:
393
+ show_file_manager()
394
+
395
+ # Sidebar
396
+ with st.sidebar:
397
+ st.subheader("βš™οΈ Settings & History")
398
+ if st.button("πŸ—‘οΈ Clear History"):
399
+ st.session_state['search_history'] = []
400
+ st.experimental_rerun()
401
+
402
+ st.markdown("### Recent Searches")
403
+ for entry in reversed(st.session_state['search_history'][-5:]):
404
+ with st.expander(f"{entry['timestamp']}: {entry['query']}"):
405
+ for i, result in enumerate(entry['results'], 1):
406
+ st.write(f"{i}. {result['description'][:100]}...")
407
+
408
+ st.markdown("### Voice Settings")
409
+ st.selectbox("TTS Voice:",
410
+ ["en-US-AriaNeural", "en-US-GuyNeural", "en-GB-SoniaNeural"],
411
+ key="tts_voice")
412
+
413
+ if __name__ == "__main__":
414
+ main()