Nicous committed on
Commit c1983fc · verified · 1 Parent(s): bd59431

Update app.py


update demo test

Files changed (1)
  1. app.py +321 -191
app.py CHANGED
@@ -5,8 +5,13 @@ import re
5
  import sys
6
  import copy
7
  import warnings
 
8
  from typing import Optional
9
10
  # Third-party imports
11
  import numpy as np
12
  import torch
@@ -22,6 +27,9 @@ from transformers import AutoModelForCausalLM, AutoTokenizer
22
 
23
  import gradio as gr
24
  import spaces
25
 
26
  # Local imports
27
  from egogpt.model.builder import load_pretrained_model
@@ -45,19 +53,21 @@ from huggingface_hub import snapshot_download
45
  # ignore_patterns=["*.md", "*.txt"] # optionally ignore some unneeded files
46
  # )
47
 
48
- from huggingface_hub import hf_hub_download
49
 
50
- # Download the model checkpoint file (large-v3.pt)
51
- ego_gpt_path = hf_hub_download(
52
- repo_id="EgoLife-v1/EgoGPT",
53
- filename="large-v3.pt",
54
- local_dir="./"
55
- )
56
 
57
 
58
  # pretrained = "/mnt/sfs-common/jkyang/EgoGPT/checkpoints/EgoGPT-llavaov-7b-EgoIT-109k-release"
59
  # pretrained = "/mnt/sfs-common/jkyang/EgoGPT/checkpoints/EgoGPT-llavaov-7b-EgoIT-EgoLife-Demo"
60
- pretrained = 'EgoLife-v1/EgoGPT'
 
 
61
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
62
  device_map = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
63
 
@@ -100,15 +110,8 @@ title_markdown = """
100
  """
101
  notice_html = """
102
  <div style="background-color: #f9f9f9; border-left: 5px solid #48dbfb; padding: 20px; margin-top: 20px; border-radius: 10px; box-shadow: 0 2px 4px rgba(0, 0, 0, 0.1);">
103
- <ul style="list-style-type: none; padding-left: 0; font-size: 1.1em; color: #555;">
104
- <li>- Due to hardware limitations on this demo page, we recommend users only try 10-second videos.</li>
105
- <li>- The demo model is used for the egocentric video captioning step for the EgoRAG framework. The recommended prompt includes:</li>
106
- <ul style="padding-left: 20px; margin-top: 10px; color: #333;">
107
- <li>Can you help me log everything I do and the key things I see, like a personal journal? Describe them in a natural style.
108
- <li>Please provide your response using the first person, with "I" as the subject. Make sure the descriptions are detailed and natural.</li>
109
- <li>Can you write down important things I notice or interact with? Please respond in the first person, using "I" as the subject. Describe them in a natural style.</li>
110
- </ul>
111
- </ul>
112
  </div>
113
  """
114
 
@@ -174,15 +177,13 @@ def load_video(
174
  vr = VideoReader(video_path, ctx=cpu(0), num_threads=1)
175
  target_sr = 16000
176
 
177
- # Add new time-based processing logic
178
  if time_based_processing:
179
  # Initialize video reader
180
  vr = decord.VideoReader(video_path, ctx=decord.cpu(0), num_threads=1)
181
  total_frame_num = len(vr)
182
-
183
- # Get the actual FPS of the video
184
  video_fps = vr.get_avg_fps()
185
-
186
  # Convert time to frame index based on the actual video FPS
187
  video_start_frame = int(time_to_frame_idx(video_start_time, video_fps))
188
  start_frame = int(time_to_frame_idx(start_time, video_fps))
@@ -208,20 +209,8 @@ def load_video(
208
 
209
  # Get the video frames for the sampled indices
210
  video = vr.get_batch(frame_idx).asnumpy()
211
- target_sr = 16000 # Set target sample rate to 16kHz
212
-
213
- # Load audio from video with resampling
214
- y, _ = librosa.load(video_path, sr=target_sr)
215
-
216
- # Convert time to audio samples (using 16kHz sample rate)
217
- start_sample = int(start_time * target_sr)
218
- end_sample = int(end_time * target_sr)
219
-
220
- # Extract audio segment
221
- speech = y[start_sample:end_sample]
222
  else:
223
- # Original processing logic
224
- speech, _ = librosa.load(video_path, sr=target_sr)
225
  total_frame_num = len(vr)
226
  avg_fps = round(vr.get_avg_fps() / fps)
227
  frame_idx = [i for i in range(0, total_frame_num, avg_fps)]
@@ -233,12 +222,32 @@ def load_video(
233
 
234
  video = vr.get_batch(frame_idx).asnumpy()
235
 
236
- # Process audio
237
- speech = whisper.pad_or_trim(speech.astype(np.float32))
238
- speech = whisper.log_mel_spectrogram(speech, n_mels=128).permute(1, 0)
239
- speech_lengths = torch.LongTensor([speech.shape[0]])
240
-
241
- return video, speech, speech_lengths
242
 
243
  class PromptRequest(BaseModel):
244
  prompt: str
@@ -251,94 +260,188 @@ class PromptRequest(BaseModel):
251
  time_based_processing: bool = False
252
 
253
  # @spaces.GPU(duration=120)
254
  def generate_text(video_path, audio_track, prompt):
 
 
255
  max_frames_num = 30
256
  fps = 1
257
- # model.eval()
258
 
259
- # Video + speech branch
260
- conv_template = "qwen_1_5" # Make sure you use correct chat template for different models
261
- question = f"<image>\n{prompt}"
262
  conv = copy.deepcopy(conv_templates[conv_template])
263
  conv.append_message(conv.roles[0], question)
264
  conv.append_message(conv.roles[1], None)
265
  prompt_question = conv.get_prompt()
266
 
267
- video, speech, speech_lengths = load_video(
268
- video_path=video_path,
269
- max_frames_num=max_frames_num,
270
- fps=fps,
271
- )
272
- speech=torch.stack([speech]).to("cuda").half()
273
- processor = model.get_vision_tower().image_processor
274
- processed_video = processor.preprocess(video, return_tensors="pt")["pixel_values"]
275
- image = [(processed_video, video[0].size, "video")]
276
-
277
- print(prompt_question)
278
- parts=split_text(prompt_question,["<image>","<speech>"])
279
- input_ids=[]
280
  for part in parts:
281
- if "<image>"==part:
282
- input_ids+=[IMAGE_TOKEN_INDEX]
283
- elif "<speech>"==part:
284
- input_ids+=[SPEECH_TOKEN_INDEX]
285
  else:
286
- input_ids+=tokenizer(part).input_ids
287
-
288
- input_ids = torch.tensor(input_ids,dtype=torch.long).unsqueeze(0).to(device)
289
- image_tensor = [image[0][0].half()]
290
- image_sizes = [image[0][1]]
291
-
292
- generate_kwargs={"eos_token_id":tokenizer.eos_token_id}
293
- print(input_ids)
294
- cont = model.generate(
295
- input_ids,
296
- images=image_tensor,
297
- image_sizes=image_sizes,
298
- speech=speech,
299
- speech_lengths=speech_lengths,
300
- do_sample=False,
301
- temperature=0.5,
302
- max_new_tokens=4096,
303
- modalities=["video"],
304
- **generate_kwargs
305
- )
306
-
307
- text_outputs = tokenizer.batch_decode(cont, skip_special_tokens=True)
308
-
309
- return text_outputs[0]
310
 
311
- def extract_audio_from_video(video_path, audio_path=None):
312
- if audio_path:
313
- try:
314
- y, sr = librosa.load(audio_path, sr=8000, mono=True, res_type='kaiser_fast')
315
- return (sr, y)
316
- except Exception as e:
317
- print(f"Error loading audio from {audio_path}: {e}")
318
- return None
319
- if video_path is None:
320
- return None
321
- try:
322
- y, sr = librosa.load(video_path, sr=8000, mono=True, res_type='kaiser_fast')
323
- return (sr, y)
324
- except Exception as e:
325
- print(f"Error extracting audio from video: {e}")
326
- return None
327
 
328
  head = """
329
  <style>
330
  /* Default and hover styles for the Submit button */
331
- button.lg.secondary.svelte-1gz44hr {
332
  background-color: #ff9933 !important;
333
  transition: background-color 0.3s ease !important;
334
  }
335
 
336
- button.lg.secondary.svelte-1gz44hr:hover {
337
  background-color: #ff7777 !important; /* darker shade on hover */
338
  }
339
 
340
  /* Keep the button text clearly visible */
341
- button.lg.secondary.svelte-1gz44hr span {
342
  color: white !important;
343
  }
344
 
@@ -360,137 +463,164 @@ button.lg.secondary.svelte-1gz44hr span {
360
  </style>
361
 
362
  <script>
363
- // New sync-control code
364
- function syncMediaElements() {
365
- // Grab the video and audio elements
366
- const video = document.querySelector('[data-testid="Video-player"] video');
367
- const waveform = document.querySelector('#waveform');
368
- const audio = waveform?.querySelector('audio') || waveform?.shadowRoot?.querySelector('audio');
369
-
370
- // Bail out if either element is missing
371
- if (!video || !audio) return;
372
-
373
- // Remove old event listeners (avoid double binding)
374
- video.removeEventListener('play', syncPlay);
375
- audio.removeEventListener('play', syncPlay);
376
- video.removeEventListener('timeupdate', syncVideoTime);
377
- audio.removeEventListener('timeupdate', syncAudioTime);
378
-
379
- // Define the sync functions
380
- function syncPlay(e) {
381
- if(e.target === video && audio.paused) audio.play();
382
- if(e.target === audio && video.paused) video.play();
383
  }
384
 
385
- function syncVideoTime() {
386
- if(!audio.seeking && Math.abs(video.currentTime - audio.currentTime) > 0.1){
387
- audio.currentTime = video.currentTime;
388
- }
 
 
389
  }
 
390
 
391
- function syncAudioTime() {
392
- if(!video.seeking && Math.abs(audio.currentTime - video.currentTime) > 0.1){
393
- video.currentTime = audio.currentTime;
394
- }
395
  }
 
396
 
397
- // Attach the new event listeners
398
- video.addEventListener('play', syncPlay);
399
- audio.addEventListener('play', syncPlay);
400
- video.addEventListener('timeupdate', syncVideoTime);
401
- audio.addEventListener('timeupdate', syncAudioTime);
 
402
 
403
- // Keep pause events in sync
404
- video.addEventListener('pause', () => audio.pause());
405
- audio.addEventListener('pause', () => video.pause());
406
 
407
- console.log('Media elements synced successfully!');
408
  }
409
 
410
- // DOM mutation observer
411
  const observer = new MutationObserver((mutations) => {
412
- mutations.forEach((mutation) => {
413
  if (mutation.addedNodes.length) {
414
- mutation.addedNodes.forEach((node) => {
415
- // Inspect newly added nodes
416
- if (node.nodeType === 1) { // Element node
417
- // Check whether the node contains the video component
418
- if (node.querySelector?.('[data-testid="Video-player"]')) {
419
- // Once the video component appears, start looking for the audio element
420
- const audioObserver = new MutationObserver(() => {
421
- if(document.querySelector('#waveform audio')) {
422
- audioObserver.disconnect();
423
- setTimeout(syncMediaElements, 500); // wait for the component to finish loading
424
- }
425
- });
426
- audioObserver.observe(document.body, {
427
- childList: true,
428
- subtree: true
429
- });
430
- }
431
- }
432
- });
433
  }
434
- });
435
  });
436
 
437
- // Start observing the whole document
438
  observer.observe(document.body, {
439
  childList: true,
440
  subtree: true
441
  });
442
 
443
- // Initial check (in case the components already exist)
444
- setTimeout(() => {
445
- if(document.querySelector('[data-testid="Video-player"] video') &&
446
- document.querySelector('#waveform audio')){
447
- syncMediaElements();
448
  }
449
- }, 1000);
450
  </script>
451
  """
452
 
453
- with gr.Blocks(head=head) as demo:
454
  gr.HTML(title_markdown)
455
  gr.HTML(notice_html)
456
 
457
  with gr.Row():
458
  with gr.Column():
459
  video_input = gr.Video(label="Video", autoplay=True, loop=True, format="mp4", width=600, height=400, show_label=False, elem_id='video')
460
- # Audio input synchronized with video playback
461
  audio_display = gr.Audio(label="Video Audio Track", autoplay=False, show_label=True, visible=True, interactive=False, elem_id="audio")
462
- text_input = gr.Textbox(label="Question", placeholder="Enter your message here...")
463
 
464
- with gr.Column(): # Create a separate column for output and examples
465
  output_text = gr.Textbox(label="Response", lines=14, max_lines=14)
466
  gr.Examples(
467
  examples=[
468
- [f"{cur_dir}/bike.mp4", f"{cur_dir}/bike.mp3", "Can you tell me what I'm doing in short words. Describe them in a natural style."],
469
- [f"{cur_dir}/bike.mp4", f"{cur_dir}/bike.mp3", "Can you tell me what I'm doing in short words. Describe them in a natural style."],
470
- [f"{cur_dir}/bike.mp4", f"{cur_dir}/bike.mp3", "Can you tell me what I'm doing in short words. Describe them in a natural style."],
471
- [f"{cur_dir}/bike.mp4", f"{cur_dir}/bike.mp3", "Can you tell me what I'm doing in short words. Describe them in a natural style."]
472
  ],
473
  inputs=[video_input, audio_display, text_input],
474
  outputs=[output_text]
475
  )
476
 
477
- # Add event handler for video changes
478
  video_input.change(
479
- fn=lambda video_path: extract_audio_from_video(video_path, audio_path=None),
480
  inputs=[video_input],
481
- outputs=[audio_display]
482
  )
483
 
484
- # Add event handler for video clear/delete
485
  def clear_outputs(video):
486
- if video is None: # Video is cleared/deleted
487
- return ""
488
- return gr.skip() # Keep existing text if video exists
489
-
490
- video_input.change(
491
  fn=clear_outputs,
492
  inputs=[video_input],
493
- outputs=[output_text]
494
  )
495
 
496
  # Add submit button and its event handler
@@ -498,11 +628,11 @@ with gr.Blocks(head=head) as demo:
498
  submit_btn.click(
499
  fn=generate_text,
500
  inputs=[video_input, audio_display, text_input],
501
- outputs=[output_text]
 
502
  )
503
 
504
  gr.Markdown(bibtext)
505
  # Launch the Gradio app
506
  if __name__ == "__main__":
507
- demo.launch(share=True)
508
-
 
5
  import sys
6
  import copy
7
  import warnings
8
+ warnings.filterwarnings("ignore", category=UserWarning)
9
  from typing import Optional
10
 
11
+ import threading
12
+ from transformers import TextIteratorStreamer
13
+
14
+
15
  # Third-party imports
16
  import numpy as np
17
  import torch
 
27
 
28
  import gradio as gr
29
  import spaces
30
+ import json
31
+ from datetime import datetime
32
+ import shutil
33
 
34
  # Local imports
35
  from egogpt.model.builder import load_pretrained_model
 
53
  # ignore_patterns=["*.md", "*.txt"] # optionally ignore some unneeded files
54
  # )
55
 
56
+ # from huggingface_hub import hf_hub_download
57
 
58
+ # # Download the model checkpoint file (large-v3.pt)
59
+ # ego_gpt_path = hf_hub_download(
60
+ # repo_id="EgoLife-v1/EgoGPT",
61
+ # filename="large-v3.pt",
62
+ # local_dir="./"
63
+ # )
64
 
65
 
66
  # pretrained = "/mnt/sfs-common/jkyang/EgoGPT/checkpoints/EgoGPT-llavaov-7b-EgoIT-109k-release"
67
  # pretrained = "/mnt/sfs-common/jkyang/EgoGPT/checkpoints/EgoGPT-llavaov-7b-EgoIT-EgoLife-Demo"
68
+ # pretrained = 'EgoLife-v1/EgoGPT'
69
+ pretrained = 'EgoLife-v1/EgoGPT-0.5b-Demo'
70
+ # pretrained = "/mnt/sfs-common/jkyang/EgoGPT_release/checkpoints/EgoGPT-7b-Demo"
71
  device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
72
  device_map = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
73
 
 
110
  """
111
  notice_html = """
112
  <div style="background-color: #f9f9f9; border-left: 5px solid #48dbfb; padding: 20px; margin-top: 20px; border-radius: 10px; box-shadow: 0 2px 4px rgba(0, 0, 0, 0.1);">
113
+ <p style="font-size: 1.1em; color: #ff9933; margin-bottom: 10px; font-weight: bold;">💡 Pro Tip: Try accessing this demo from your phone's browser. You can use your phone's camera to capture and analyze egocentric videos, making the experience more interactive and personal.</p>
114
+ <p style="font-size: 1.1em; color: #555; margin-bottom: 10px;">EgoGPT-7B is built upon LLaVA-OV and has been finetuned on the EgoIT dataset and a partially de-identified EgoLife dataset. Its primary goal is to serve as an egocentric captioner, supporting EgoRAG for EgoLifeQA tasks. Please note that due to inherent biases in the EgoLife dataset, the model may occasionally hallucinate details about people in custom videos based on patterns from the training data (for example, describing someone as "wearing a blue t-shirt" or "with pink hair"). We are actively working on improving the model to make it more universally applicable and will continue to release updates regularly. If you're interested in contributing to the development of future iterations of EgoGPT or the EgoLife project, we welcome you to reach out and contact us. (Contact us at <a href="mailto:[email protected]">[email protected]</a>)</p>
115
  </div>
116
  """
117
 
 
177
  vr = VideoReader(video_path, ctx=cpu(0), num_threads=1)
178
  target_sr = 16000
179
 
180
+ # Process video frames first
181
  if time_based_processing:
182
  # Initialize video reader
183
  vr = decord.VideoReader(video_path, ctx=decord.cpu(0), num_threads=1)
184
  total_frame_num = len(vr)
 
 
185
  video_fps = vr.get_avg_fps()
186
+
187
  # Convert time to frame index based on the actual video FPS
188
  video_start_frame = int(time_to_frame_idx(video_start_time, video_fps))
189
  start_frame = int(time_to_frame_idx(start_time, video_fps))
 
209
 
210
  # Get the video frames for the sampled indices
211
  video = vr.get_batch(frame_idx).asnumpy()
 
212
  else:
213
+ # Original video processing logic
 
214
  total_frame_num = len(vr)
215
  avg_fps = round(vr.get_avg_fps() / fps)
216
  frame_idx = [i for i in range(0, total_frame_num, avg_fps)]
 
222
 
223
  video = vr.get_batch(frame_idx).asnumpy()
224
 
225
+ # Try to load audio, return None for speech if failed
226
+ try:
227
+ if time_based_processing:
228
+ y, _ = librosa.load(video_path, sr=target_sr)
229
+ start_sample = int(start_time * target_sr)
230
+ end_sample = int(end_time * target_sr)
231
+ speech = y[start_sample:end_sample]
232
+ else:
233
+ speech, _ = librosa.load(video_path, sr=target_sr)
234
+
235
+ # Process audio if it exists
236
+ speech = whisper.pad_or_trim(speech.astype(np.float32))
237
+ speech = whisper.log_mel_spectrogram(speech, n_mels=128).permute(1, 0)
238
+ speech_lengths = torch.LongTensor([speech.shape[0]])
239
+
240
+ return video, speech, speech_lengths, True # True indicates real audio
241
+
242
+ except Exception as e:
243
+ print(f"Warning: Could not load audio from video: {e}")
244
+ # Create dummy silent audio
245
+ duration = 10 # 10 seconds
246
+ speech = np.zeros(duration * target_sr, dtype=np.float32)
247
+ speech = whisper.pad_or_trim(speech)
248
+ speech = whisper.log_mel_spectrogram(speech, n_mels=128).permute(1, 0)
249
+ speech_lengths = torch.LongTensor([speech.shape[0]])
250
+ return video, speech, speech_lengths, False # False indicates no real audio
251
 
252
  class PromptRequest(BaseModel):
253
  prompt: str
 
260
  time_based_processing: bool = False
261
 
262
  # @spaces.GPU(duration=120)
263
+ def save_interaction(video_path, prompt, output, audio_path=None):
264
+ """Save user interaction data and files"""
265
+ if not video_path:
266
+ return
267
+
268
+ # Create timestamped directory for this interaction
269
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
270
+ interaction_dir = os.path.join(UPLOADS_DIR, timestamp)
271
+ os.makedirs(interaction_dir, exist_ok=True)
272
+
273
+ # Copy video file
274
+ video_ext = os.path.splitext(video_path)[1]
275
+ new_video_path = os.path.join(interaction_dir, f"video{video_ext}")
276
+ shutil.copy2(video_path, new_video_path)
277
+
278
+ # Save metadata
279
+ metadata = {
280
+ "timestamp": timestamp,
281
+ "prompt": prompt,
282
+ "output": output,
283
+ "video_path": new_video_path,
284
+ }
285
+
286
+ # Only try to save audio if it's a file path (str), not audio data (tuple)
287
+ if audio_path and isinstance(audio_path, (str, bytes, os.PathLike)):
288
+ audio_ext = os.path.splitext(audio_path)[1]
289
+ new_audio_path = os.path.join(interaction_dir, f"audio{audio_ext}")
290
+ shutil.copy2(audio_path, new_audio_path)
291
+ metadata["audio_path"] = new_audio_path
292
+
293
+ with open(os.path.join(interaction_dir, "metadata.json"), "w") as f:
294
+ json.dump(metadata, f, indent=4)
295
+
296
+ def extract_audio_from_video(video_path, audio_path=None):
297
+ print('Processing audio from video...', video_path, audio_path)
298
+ if video_path is None:
299
+ return None
300
+
301
+ if isinstance(video_path, dict) and 'name' in video_path:
302
+ video_path = video_path['name']
303
+
304
+ try:
305
+ y, sr = librosa.load(video_path, sr=8000, mono=True, res_type='kaiser_fast')
306
+ # Check if the audio is silent
307
+ if np.abs(y).mean() < 0.001:
308
+ print("Video appears to be silent")
309
+ return None
310
+ return (sr, y)
311
+ except Exception as e:
312
+ print(f"Warning: Could not extract audio from video: {e}")
313
+ return None
314
+
315
+ import time
316
+
317
  def generate_text(video_path, audio_track, prompt):
318
+ streamer = TextIteratorStreamer(tokenizer, skip_special_tokens=True)
319
+
320
  max_frames_num = 30
321
  fps = 1
322
+ conv_template = "qwen_1_5"
323
+ if video_path is None and audio_track is None:
324
+ question = prompt
325
+ speech = None
326
+ speech_lengths = None
327
+ has_real_audio = False
328
+ image = None
329
+ image_sizes= None
330
+ modalities = ["image"]
331
+ image_tensor=None
332
+ # Load video and potentially audio
333
+ else:
334
+ video, speech, speech_lengths, has_real_audio = load_video(
335
+ video_path=video_path,
336
+ max_frames_num=max_frames_num,
337
+ fps=fps,
338
+ )
339
+
340
+ # Prepare the prompt based on whether we have real audio
341
+ if not has_real_audio:
342
+ question = f"<image>\n{prompt}" # Video-only prompt
343
+ else:
344
+ question = f"<speech>\n<image>\n{prompt}" # Video + speech prompt
345
+
346
+ speech = torch.stack([speech]).to("cuda").half()
347
+ processor = model.get_vision_tower().image_processor
348
+ processed_video = processor.preprocess(video, return_tensors="pt")["pixel_values"]
349
+ image = [(processed_video, video[0].size, "video")]
350
+ image_tensor = [image[0][0].half()]
351
+ image_sizes = [image[0][1]]
352
+ modalities = ["video"]
353
 
354
  conv = copy.deepcopy(conv_templates[conv_template])
355
  conv.append_message(conv.roles[0], question)
356
  conv.append_message(conv.roles[1], None)
357
  prompt_question = conv.get_prompt()
358
 
359
+
360
+
361
+ parts = split_text(prompt_question, ["<image>", "<speech>"])
362
+ input_ids = []
363
  for part in parts:
364
+ if "<image>" == part:
365
+ input_ids += [IMAGE_TOKEN_INDEX]
366
+ elif "<speech>" == part and speech is not None: # Only add speech token if we have audio
367
+ input_ids += [SPEECH_TOKEN_INDEX]
368
  else:
369
+ input_ids += tokenizer(part).input_ids
370
+
371
+ input_ids = torch.tensor(input_ids, dtype=torch.long).unsqueeze(0).to(device)
372
+
373
+
374
+ generate_kwargs = {"eos_token_id": tokenizer.eos_token_id}
375
+
376
+ def generate_response():
377
+ model.generate(
378
+ input_ids,
379
+ images=image_tensor,
380
+ image_sizes=image_sizes,
381
+ speech=speech,
382
+ speech_lengths=speech_lengths,
383
+ do_sample=False,
384
+ temperature=0.7,
385
+ max_new_tokens=512,
386
+ repetition_penalty=1.2,
387
+ modalities=modalities,
388
+ streamer=streamer,
389
+ **generate_kwargs
390
+ )
391
+
392
+ # Start generation in a separate thread
393
+ thread = threading.Thread(target=generate_response)
394
+ thread.start()
395
+
396
+ # Stream the output word by word
397
+ generated_text = ""
398
+ partial_word = ""
399
+ cursor = "|"
400
+ cursor_visible = True
401
+ last_cursor_toggle = time.time()
402
+
403
+ for new_text in streamer:
404
+ partial_word += new_text
405
+ # Toggle the cursor visibility every 0.5 seconds
406
+ if time.time() - last_cursor_toggle > 0.5:
407
+ cursor_visible = not cursor_visible
408
+ last_cursor_toggle = time.time()
409
+ current_cursor = cursor if cursor_visible else " "
410
+ if partial_word.endswith(" ") or partial_word.endswith("\n"):
411
+ generated_text += partial_word
412
+ # Yield the current text with the cursor appended
413
+ yield generated_text + current_cursor
414
+ partial_word = ""
415
+ else:
416
+ # Yield the current text plus the partial word and the cursor
417
+ yield generated_text + partial_word + current_cursor
418
 
419
+ # Handle any remaining partial word at the end
420
+ if partial_word:
421
+ generated_text += partial_word
422
+ yield generated_text
423
+
424
+ # Save the interaction after generation is complete
425
+ save_interaction(video_path, prompt, generated_text, audio_track)
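The rewritten generate_text streams its answer by running model.generate in a background thread with a TextIteratorStreamer and yielding the accumulated text back to Gradio. A stripped-down sketch of that pattern with a plain text-only Hugging Face model; the model id and prompt are placeholders, and EgoGPT's image/speech arguments are omitted:

```python
# Minimal streaming-generation sketch (illustrative; not EgoGPT-specific).
import threading
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

model_id = "Qwen/Qwen2-0.5B-Instruct"                  # placeholder model
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

def stream_reply(prompt: str):
    inputs = tokenizer(prompt, return_tensors="pt")
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
    thread = threading.Thread(
        target=model.generate,
        kwargs=dict(**inputs, streamer=streamer, max_new_tokens=128, do_sample=False),
    )
    thread.start()                                     # generation runs in the background
    text = ""
    for chunk in streamer:                             # blocks until the next decoded piece arrives
        text += chunk
        yield text                                     # each yield re-renders the output textbox

for partial in stream_reply("Describe a short bike ride."):
    print(partial)
```

Each yield from generate_text triggers a UI update, which is what produces the typewriter effect with the toggling "|" cursor in the demo.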
426
 
427
  head = """
428
+ <head>
429
+ <title>EgoGPT Demo - EgoLife</title>
430
+ <link rel="icon" type="image/x-icon" href="./egolife_circle.ico">
431
+ </head>
432
  <style>
433
  /* Default and hover styles for the Submit button */
434
+ button.lg.secondary.svelte-5st68j {
435
  background-color: #ff9933 !important;
436
  transition: background-color 0.3s ease !important;
437
  }
438
 
439
+ button.lg.secondary.svelte-5st68j:hover {
440
  background-color: #ff7777 !important; /* darker shade on hover */
441
  }
442
 
443
  /* Keep the button text clearly visible */
444
+ button.lg.secondary.svelte-5st68j span {
445
  color: white !important;
446
  }
447
 
 
463
  </style>
464
 
465
  <script>
466
+ function initializeControls() {
467
+ const video = document.querySelector('[data-testid="Video-player"]');
468
+ const waveform = document.getElementById('waveform');
469
+
470
+ // Return early if the elements are not ready yet
471
+ if (!video || !waveform) {
472
+ return;
473
+ }
474
+
475
+ // Try to get the audio element
476
+ const audio = waveform.querySelector('div')?.shadowRoot?.querySelector('audio');
477
+ if (!audio) {
478
+ return;
479
  }
480
 
481
+ console.log('Elements found:', { video, audio });
482
+
483
+ // Listen for video playback
484
+ video.addEventListener("play", () => {
485
+ if (audio.paused) {
486
+ audio.play(); // start the audio if it is paused
487
  }
488
+ });
489
 
490
+ // Listen for audio playback
491
+ audio.addEventListener("play", () => {
492
+ if (video.paused) {
493
+ video.play(); // start the video if it is paused
494
  }
495
+ });
496
 
497
+ // Keep video and audio playback positions in sync
498
+ video.addEventListener("timeupdate", () => {
499
+ if (Math.abs(video.currentTime - audio.currentTime) > 0.1) {
500
+ audio.currentTime = video.currentTime; // resync if they drift by more than 0.1 s
501
+ }
502
+ });
503
 
504
+ audio.addEventListener("timeupdate", () => {
505
+ if (Math.abs(audio.currentTime - video.currentTime) > 0.1) {
506
+ video.currentTime = audio.currentTime; // resync if they drift by more than 0.1 s
507
+ }
508
+ });
509
+
510
+ // Listen for pause events so video and audio pause together
511
+ video.addEventListener("pause", () => {
512
+ if (!audio.paused) {
513
+ audio.pause(); // pause the audio if it is still playing
514
+ }
515
+ });
516
 
517
+ audio.addEventListener("pause", () => {
518
+ if (!video.paused) {
519
+ video.pause(); // pause the video if it is still playing
520
+ }
521
+ });
522
  }
523
 
524
+ // Create an observer to watch for DOM changes
525
  const observer = new MutationObserver((mutations) => {
526
+ for (const mutation of mutations) {
527
  if (mutation.addedNodes.length) {
528
+ // When new nodes are added, try to initialize
529
+ const waveform = document.getElementById('waveform');
530
+ if (waveform?.querySelector('div')?.shadowRoot?.querySelector('audio')) {
531
+ console.log('Audio element detected');
532
+ initializeControls();
533
+ // Optional: disconnect the observer if no further watching is needed
534
+ // observer.disconnect();
535
+ }
536
  }
537
+ }
538
  });
539
 
540
+ // Start observing
541
  observer.observe(document.body, {
542
  childList: true,
543
  subtree: true
544
  });
545
 
546
+ // Also try to initialize once the page has loaded
547
+ document.addEventListener('DOMContentLoaded', () => {
548
+ console.log('DOM Content Loaded');
549
+ initializeControls();
550
+
551
+ // Ensure title and favicon are set correctly
552
+ document.title = "EgoGPT Demo - EgoLife";
553
+
554
+ // Create/update favicon link
555
+ let link = document.querySelector("link[rel~='icon']");
556
+ if (!link) {
557
+ link = document.createElement('link');
558
+ link.rel = 'icon';
559
+ document.head.appendChild(link);
560
  }
561
+ link.href = './egolife_circle.ico';
562
+
563
+ });
564
+
565
  </script>
566
  """
567
 
568
+ with gr.Blocks(title="EgoGPT Demo - EgoLife", head=head) as demo:
569
  gr.HTML(title_markdown)
570
  gr.HTML(notice_html)
571
 
572
  with gr.Row():
573
  with gr.Column():
574
  video_input = gr.Video(label="Video", autoplay=True, loop=True, format="mp4", width=600, height=400, show_label=False, elem_id='video')
575
+ # Make audio display conditionally visible
576
  audio_display = gr.Audio(label="Video Audio Track", autoplay=False, show_label=True, visible=True, interactive=False, elem_id="audio")
577
+ text_input = gr.Textbox(label="Question", placeholder="Enter your message here...", value="Describe everything I saw, did, and heard, using the first perspective. Transcribe all the speech.")
578
 
579
+ with gr.Column():
580
  output_text = gr.Textbox(label="Response", lines=14, max_lines=14)
581
  gr.Examples(
582
  examples=[
583
+ [f"{cur_dir}/videos/cheers.mp4", f"{cur_dir}/videos/cheers.mp3", "Describe everything I saw, did, and heard from the first perspective."],
584
+ [f"{cur_dir}/videos/DAY3_A6_SHURE_14550000.mp4", f"{cur_dir}/videos/DAY3_A6_SHURE_14550000.mp3", "请按照时间顺序描述我所见所为,并转录所有声音。"],
585
+ [f"{cur_dir}/videos/shopping.mp4", f"{cur_dir}/videos/shopping.mp3", "Please only transcribe all the speech."],
586
+ [f"{cur_dir}/videos/japan.mp4", f"{cur_dir}/videos/japan.mp3", "Describe everything I see, do, and hear from the first-person view."],
587
  ],
588
  inputs=[video_input, audio_display, text_input],
589
  outputs=[output_text]
590
  )
591
 
592
+ def handle_video_change(video):
593
+ if video is None:
594
+ return gr.update(visible=False), None
595
+
596
+ audio = extract_audio_from_video(video)
597
+ # Update audio display visibility based on whether we have audio
598
+ return gr.update(visible=audio is not None), audio
599
+
600
+ # Update the video input change event
601
  video_input.change(
602
+ fn=handle_video_change,
603
  inputs=[video_input],
604
+ outputs=[audio_display, audio_display] # First for visibility, second for audio data
605
  )
606
 
607
+ # Add clear handler
608
  def clear_outputs(video):
609
+ if video is None:
610
+ return gr.update(visible=False), "", None
611
+ return gr.skip()
612
+
613
+ video_input.clear(
614
  fn=clear_outputs,
615
  inputs=[video_input],
616
+ outputs=[audio_display, output_text, audio_display]
617
+ )
618
+
619
+ text_input.submit(
620
+ fn=generate_text,
621
+ inputs=[video_input, audio_display, text_input],
622
+ outputs=[output_text],
623
+ api_name="generate_streaming"
624
  )
625
 
626
  # Add submit button and its event handler
 
628
  submit_btn.click(
629
  fn=generate_text,
630
  inputs=[video_input, audio_display, text_input],
631
+ outputs=[output_text],
632
+ api_name="generate_streaming"
633
  )
634
 
635
  gr.Markdown(bibtext)
636
  # Launch the Gradio app
637
  if __name__ == "__main__":
638
+ demo.launch(share=True)
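Because the submit handlers are registered with api_name="generate_streaming", the demo can also be driven programmatically. A hypothetical client-side call, assuming gradio_client 1.x (file inputs wrapped with handle_file) and that the app is reachable at the URL shown; adjust the URL, file paths, and endpoint name to the actual deployment:

```python
# Hypothetical client call against the demo's named endpoint (assumptions noted above).
from gradio_client import Client, handle_file

client = Client("http://127.0.0.1:7860/")           # or the Hugging Face Space id
result = client.predict(
    handle_file("videos/cheers.mp4"),               # video_input
    handle_file("videos/cheers.mp3"),               # audio_display (separate audio track)
    "Describe everything I saw, did, and heard.",   # text_input
    api_name="/generate_streaming",
)
print(result)                                       # final text once streaming completes
```

client.predict waits for the generator to finish and returns the last yielded text; client.submit can be used instead to observe intermediate updates.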