### ----------------- ###
# Standard library imports
import os
import re
import sys
import copy
import subprocess
import warnings
from typing import Optional

# Third-party imports
import numpy as np
import torch
import torch.distributed as dist
import uvicorn
import librosa
import whisper
import requests
from fastapi import FastAPI
from pydantic import BaseModel
from decord import VideoReader, cpu
from transformers import AutoModelForCausalLM, AutoTokenizer
import gradio as gr
import spaces

# Local imports
from egogpt.model.builder import load_pretrained_model
from egogpt.mm_utils import get_model_name_from_path, process_images
from egogpt.constants import (
    IMAGE_TOKEN_INDEX,
    DEFAULT_IMAGE_TOKEN,
    IGNORE_INDEX,
    SPEECH_TOKEN_INDEX,
    DEFAULT_SPEECH_TOKEN,
)
from egogpt.conversation import conv_templates, SeparatorStyle

# Install flash-attn at runtime, skipping the CUDA build (needed on Spaces runners)
subprocess.run(
    'pip install flash-attn --no-build-isolation',
    env={'FLASH_ATTENTION_SKIP_CUDA_BUILD': "TRUE"},
    shell=True,
)
# pretrained = "/mnt/sfs-common/jkyang/EgoGPT/checkpoints/EgoGPT-llavaov-7b-EgoIT-109k-release" | |
# pretrained = "/mnt/sfs-common/jkyang/EgoGPT/checkpoints/EgoGPT-llavaov-7b-EgoIT-EgoLife-Demo" | |
pretrained = '/EgoLife-v1/EgoGPT' | |
device = "cuda" | |
device_map = "cuda" | |
# Add this initialization code before loading the model | |
def setup(rank, world_size): | |
os.environ['MASTER_ADDR'] = 'localhost' | |
os.environ['MASTER_PORT'] = '12377' | |
# initialize the process group | |
dist.init_process_group("gloo", rank=rank, world_size=world_size) | |
setup(0,1) | |
tokenizer, model, max_length = load_pretrained_model(pretrained,device_map=device_map) | |
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') | |
model.to(device).eval() | |
title_markdown = """ | |
<div style="display: flex; justify-content: space-between; align-items: center; background: linear-gradient(90deg, rgba(72,219,251,0.1), rgba(29,209,161,0.1)); border-radius: 20px; box-shadow: 0 4px 6px rgba(0,0,0,0.1); padding: 20px; margin-bottom: 20px;"> | |
<div style="display: flex; align-items: center;"> | |
<a href="https://egolife-ntu.github.io/" style="margin-right: 20px; text-decoration: none; display: flex; align-items: center;"> | |
<img src="https://egolife-ntu.github.io/egolife.png" alt="EgoLife" style="max-width: 100px; height: auto; border-radius: 15px; box-shadow: 0 2px 4px rgba(0,0,0,0.1);"> | |
</a> | |
<div> | |
<h1 style="margin: 0; background: linear-gradient(90deg, #48dbfb, #1dd1a1); -webkit-background-clip: text; -webkit-text-fill-color: transparent; font-size: 2.5em; font-weight: 700;">EgoLife</h1> | |
<h2 style="margin: 10px 0; color: #2d3436; font-weight: 500;">Towards Egocentric Life Assistant</h2> | |
<div style="display: flex; gap: 15px; margin-top: 10px;"> | |
<a href="https://egolife-ntu.github.io/" style="text-decoration: none; color: #48dbfb; font-weight: 500; transition: color 0.3s;">Project Page</a> | | |
<a href="https://github.com/egolife-ntu/EgoGPT" style="text-decoration: none; color: #48dbfb; font-weight: 500; transition: color 0.3s;">Github</a> | | |
<a href="https://huggingface.co/lmms-lab" style="text-decoration: none; color: #48dbfb; font-weight: 500; transition: color 0.3s;">Huggingface</a> | | |
<a href="https://arxiv.org/" style="text-decoration: none; color: #48dbfb; font-weight: 500; transition: color 0.3s;">Paper</a> | | |
<a href="https://x.com/" style="text-decoration: none; color: #48dbfb; font-weight: 500; transition: color 0.3s;">Twitter (X)</a> | |
</div> | |
</div> | |
</div> | |
<div style="text-align: right; margin-left: 20px;"> | |
<h1 style="margin: 0; background: linear-gradient(90deg, #48dbfb, #1dd1a1); -webkit-background-clip: text; -webkit-text-fill-color: transparent; font-size: 2.5em; font-weight: 700;">EgoGPT</h1> | |
<h2 style="margin: 10px 0; background: linear-gradient(90deg, #48dbfb, #1dd1a1); -webkit-background-clip: text; -webkit-text-fill-color: transparent; font-size: 1.8em; font-weight: 600;">An Egocentric Video-Audio-Text Model<br>from EgoLife Project</h2> | |
</div> | |
</div> | |
""" | |
bibtext = """ | |
### Citation | |
``` | |
@article{yang2025egolife, | |
title={EgoLife\: Towards Egocentric Life Assistant}, | |
author={The EgoLife Team}, | |
journal={arXiv preprint arXiv:25xxx}, | |
year={2025} | |
} | |
``` | |
""" | |

# cur_dir = os.path.dirname(os.path.abspath(__file__))
cur_dir = ''

def time_to_frame_idx(time_int: int, fps: int) -> int:
    """
    Convert time in HHMMSSFF format (integer or string) to frame index.
    :param time_int: Time in HHMMSSFF format, e.g., 10483000 (10:48:30.00) or "10483000".
    :param fps: Frames per second of the video.
    :return: Frame index corresponding to the given time.
    """
    # Ensure time_int is a string for slicing; pad with zeros so it is 8 digits
    time_str = str(time_int).zfill(8)
    hours = int(time_str[:2])
    minutes = int(time_str[2:4])
    seconds = int(time_str[4:6])
    frames = int(time_str[6:8])
    total_seconds = hours * 3600 + minutes * 60 + seconds
    total_frames = total_seconds * fps + frames  # Convert to total frames
    return total_frames
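
# Quick sanity check of the HHMMSSFF conversion (illustrative values, not from the demo):
# 10483000 means 10:48:30.00, so at 30 fps it maps to
# (10 * 3600 + 48 * 60 + 30) * 30 + 0 = 1,167,300 frames.
# assert time_to_frame_idx(10483000, 30) == 1167300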

def split_text(text, keywords):
    # Build a regex pattern that joins all keywords with | and wraps them in a capture group
    pattern = '(' + '|'.join(map(re.escape, keywords)) + ')'
    # re.split keeps the delimiters because of the capture group
    parts = re.split(pattern, text)
    # Drop empty strings
    parts = [part for part in parts if part]
    return parts
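
# Example of what split_text produces (illustrative, using the demo's own special tokens):
#   split_text("<image>\nWhat am I doing?", ["<image>", "<speech>"])
#   -> ['<image>', '\nWhat am I doing?']
# This lets generate_text() swap "<image>"/"<speech>" for their token indices while the
# remaining plain-text pieces go through the tokenizer.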

warnings.filterwarnings("ignore")

# Create FastAPI instance
app = FastAPI()

def load_video(
    video_path: Optional[str] = None,
    max_frames_num: int = 16,
    fps: int = 1,
    video_start_time: Optional[float] = None,
    start_time: Optional[float] = None,
    end_time: Optional[float] = None,
    time_based_processing: bool = False
) -> tuple:
    # Initialize video reader
    vr = VideoReader(video_path, ctx=cpu(0), num_threads=1)
    target_sr = 16000

    # Time-based processing: crop video and audio to the [start_time, end_time] window
    if time_based_processing:
        total_frame_num = len(vr)

        # Get the actual FPS of the video
        video_fps = vr.get_avg_fps()

        # Convert times to frame indices based on the actual video FPS
        video_start_frame = int(time_to_frame_idx(video_start_time, video_fps))
        start_frame = int(time_to_frame_idx(start_time, video_fps))
        end_frame = int(time_to_frame_idx(end_time, video_fps))

        print("start frame", start_frame)
        print("end frame", end_frame)

        # Ensure the end time does not exceed the total frame number
        if end_frame - start_frame > total_frame_num:
            end_frame = total_frame_num + start_frame

        # Adjust start_frame and end_frame relative to the video start time
        start_frame -= video_start_frame
        end_frame -= video_start_frame
        start_frame = max(0, int(round(start_frame)))  # Never below 0
        end_frame = min(total_frame_num, int(round(end_frame)))  # Never beyond the total frame count

        # Sample frames at the requested fps (e.g., 1 frame per second)
        frame_idx = [i for i in range(start_frame, end_frame) if (i - start_frame) % int(video_fps / fps) == 0]

        # Get the video frames for the sampled indices
        video = vr.get_batch(frame_idx).asnumpy()

        # Load audio from the video, resampled to 16 kHz
        y, _ = librosa.load(video_path, sr=target_sr)

        # Convert times to audio sample indices (at the 16 kHz sample rate)
        start_sample = int(start_time * target_sr)
        end_sample = int(end_time * target_sr)

        # Extract audio segment
        speech = y[start_sample:end_sample]
    else:
        # Original processing logic: sample the whole clip
        speech, _ = librosa.load(video_path, sr=target_sr)
        total_frame_num = len(vr)
        avg_fps = round(vr.get_avg_fps() / fps)
        frame_idx = [i for i in range(0, total_frame_num, avg_fps)]
        if max_frames_num > 0 and len(frame_idx) > max_frames_num:
            uniform_sampled_frames = np.linspace(0, total_frame_num - 1, max_frames_num, dtype=int)
            frame_idx = uniform_sampled_frames.tolist()
        video = vr.get_batch(frame_idx).asnumpy()

    # Process audio: pad/trim to 30 s and compute a 128-bin log-mel spectrogram
    speech = whisper.pad_or_trim(speech.astype(np.float32))
    speech = whisper.log_mel_spectrogram(speech, n_mels=128).permute(1, 0)
    speech_lengths = torch.LongTensor([speech.shape[0]])
    return video, speech, speech_lengths
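
# Shape notes (derived from the calls above, not extra guarantees from EgoGPT itself):
# - video is a (num_frames, H, W, 3) uint8 array from decord.
# - speech is a (3000, 128) float tensor: whisper.pad_or_trim fixes the waveform to 30 s
#   and log_mel_spectrogram(..., n_mels=128) yields (128, 3000), which permute(1, 0) transposes.
# - speech_lengths is a 1-element LongTensor holding speech.shape[0].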

class PromptRequest(BaseModel):
    prompt: str
    video_path: Optional[str] = None
    max_frames_num: int = 16
    fps: int = 1
    video_start_time: Optional[float] = None
    start_time: Optional[float] = None
    end_time: Optional[float] = None
    time_based_processing: bool = False

# @spaces.GPU(duration=120)
def generate_text(video_path, audio_track, prompt):
    max_frames_num = 30
    fps = 1
    # model.eval()

    # Video + speech branch
    conv_template = "qwen_1_5"  # Make sure you use the correct chat template for different models
    question = f"<image>\n{prompt}"
    conv = copy.deepcopy(conv_templates[conv_template])
    conv.append_message(conv.roles[0], question)
    conv.append_message(conv.roles[1], None)
    prompt_question = conv.get_prompt()

    video, speech, speech_lengths = load_video(
        video_path=video_path,
        max_frames_num=max_frames_num,
        fps=fps,
    )
    speech = torch.stack([speech]).to("cuda").half()
    processor = model.get_vision_tower().image_processor
    processed_video = processor.preprocess(video, return_tensors="pt")["pixel_values"]
    image = [(processed_video, video[0].size, "video")]

    print(prompt_question)

    # Tokenize the prompt, mapping <image>/<speech> placeholders to their token indices
    parts = split_text(prompt_question, ["<image>", "<speech>"])
    input_ids = []
    for part in parts:
        if part == "<image>":
            input_ids += [IMAGE_TOKEN_INDEX]
        elif part == "<speech>":
            input_ids += [SPEECH_TOKEN_INDEX]
        else:
            input_ids += tokenizer(part).input_ids
    input_ids = torch.tensor(input_ids, dtype=torch.long).unsqueeze(0).to(device)

    image_tensor = [image[0][0].half()]
    image_sizes = [image[0][1]]
    generate_kwargs = {"eos_token_id": tokenizer.eos_token_id}

    print(input_ids)
    cont = model.generate(
        input_ids,
        images=image_tensor,
        image_sizes=image_sizes,
        speech=speech,
        speech_lengths=speech_lengths,
        do_sample=False,
        temperature=0.5,  # Ignored while do_sample=False (greedy decoding)
        max_new_tokens=4096,
        modalities=["video"],
        **generate_kwargs,
    )
    text_outputs = tokenizer.batch_decode(cont, skip_special_tokens=True)
    return text_outputs[0]
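
# A minimal local sanity check (hypothetical paths, not files shipped with this Space):
#   print(generate_text("videos/bike.mp4", None, "What am I doing?"))
# Note: the audio_track argument only mirrors the Gradio inputs; the audio is re-extracted
# from video_path inside load_video().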

def extract_audio_from_video(video_path, audio_path=None):
    # Prefer an explicit audio file if one is provided
    if audio_path:
        try:
            y, sr = librosa.load(audio_path, sr=8000, mono=True, res_type='kaiser_fast')
            return (sr, y)
        except Exception as e:
            print(f"Error loading audio from {audio_path}: {e}")
            return None
    if video_path is None:
        return None
    try:
        y, sr = librosa.load(video_path, sr=8000, mono=True, res_type='kaiser_fast')
        return (sr, y)
    except Exception as e:
        print(f"Error extracting audio from video: {e}")
        return None
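
# The (sr, y) tuple matches what gr.Audio accepts as a value: a sample rate plus a NumPy
# waveform. Returning None leaves the audio component empty when no source is available.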
head = """ | |
<style> | |
/* Submit按钮默认和悬停效果 */ | |
button.lg.secondary.svelte-5st68j { | |
background-color: #ff9933 !important; | |
transition: background-color 0.3s ease !important; | |
} | |
button.lg.secondary.svelte-5st68j:hover { | |
background-color: #ff7777 !important; /* 悬停时颜色加深 */ | |
} | |
/* 确保按钮文字始终清晰可见 */ | |
button.lg.secondary.svelte-5st68j span { | |
color: white !important; | |
} | |
/* 隐藏表头中的第二列 */ | |
.table-wrap .svelte-p5q82i th:nth-child(2) { | |
display: none; | |
} | |
/* 隐藏表格内容中的第二列 */ | |
.table-wrap .svelte-p5q82i td:nth-child(2) { | |
display: none; | |
} | |
.table-wrap { | |
max-height: 300px; | |
overflow-y: auto; | |
} | |
</style> | |
<script> | |
function initializeControls() { | |
const video = document.querySelector('[data-testid="Video-player"]'); | |
const waveform = document.getElementById('waveform'); | |
// 如果元素还没准备好,直接返回 | |
if (!video || !waveform) { | |
return; | |
} | |
// 尝试获取音频元素 | |
const audio = waveform.querySelector('div')?.shadowRoot?.querySelector('audio'); | |
if (!audio) { | |
return; | |
} | |
console.log('Elements found:', { video, audio }); | |
// 监听视频播放进度 | |
video.addEventListener("play", () => { | |
if (audio.paused) { | |
audio.play(); // 如果音频暂停,开始播放 | |
} | |
}); | |
// 监听音频播放进度 | |
audio.addEventListener("play", () => { | |
if (video.paused) { | |
video.play(); // 如果视频暂停,开始播放 | |
} | |
}); | |
// 同步视频和音频的播放进度 | |
video.addEventListener("timeupdate", () => { | |
if (Math.abs(video.currentTime - audio.currentTime) > 0.1) { | |
audio.currentTime = video.currentTime; // 如果时间差超过0.1秒,同步 | |
} | |
}); | |
audio.addEventListener("timeupdate", () => { | |
if (Math.abs(audio.currentTime - video.currentTime) > 0.1) { | |
video.currentTime = audio.currentTime; // 如果时间差超过0.1秒,同步 | |
} | |
}); | |
// 监听暂停事件,确保视频和音频都暂停 | |
video.addEventListener("pause", () => { | |
if (!audio.paused) { | |
audio.pause(); // 如果音频未暂停,暂停音频 | |
} | |
}); | |
audio.addEventListener("pause", () => { | |
if (!video.paused) { | |
video.pause(); // 如果视频未暂停,暂停视频 | |
} | |
}); | |
} | |
// 创建观察器监听DOM变化 | |
const observer = new MutationObserver((mutations) => { | |
for (const mutation of mutations) { | |
if (mutation.addedNodes.length) { | |
// 当有新节点添加时,尝试初始化 | |
const waveform = document.getElementById('waveform'); | |
if (waveform?.querySelector('div')?.shadowRoot?.querySelector('audio')) { | |
console.log('Audio element detected'); | |
initializeControls(); | |
// 可选:如果不需要继续监听,可以断开观察器 | |
// observer.disconnect(); | |
} | |
} | |
} | |
}); | |
// 开始观察 | |
observer.observe(document.body, { | |
childList: true, | |
subtree: true | |
}); | |
// 页面加载完成时也尝试初始化 | |
document.addEventListener('DOMContentLoaded', () => { | |
console.log('DOM Content Loaded'); | |
initializeControls(); | |
}); | |
</script> | |
""" | |

with gr.Blocks(head=head) as demo:
    gr.Markdown(title_markdown)

    with gr.Row():
        with gr.Column():
            video_input = gr.Video(label="Video", autoplay=True, loop=True, format="mp4", width=600, height=400, show_label=False, elem_id='video')
            # Audio input synchronized with video playback
            audio_display = gr.Audio(label="Video Audio Track", autoplay=False, show_label=True, visible=True, interactive=False, elem_id="audio")
            text_input = gr.Textbox(label="Question", placeholder="Enter your message here...")
        with gr.Column():  # Create a separate column for output and examples
            output_text = gr.Textbox(label="Response", lines=14, max_lines=14)
            gr.Examples(
                examples=[
                    [f"{cur_dir}/videos/bike.mp4", f"{cur_dir}/videos/bike.mp3", "Can you tell me what I'm doing in short words. Describe them in a natural style."],
                    [f"{cur_dir}/videos/bike.mp4", f"{cur_dir}/videos/bike.mp3", "Can you tell me what I'm doing in short words. Describe them in a natural style."],
                    [f"{cur_dir}/videos/bike.mp4", f"{cur_dir}/videos/bike.mp3", "Can you tell me what I'm doing in short words. Describe them in a natural style."],
                    [f"{cur_dir}/videos/bike.mp4", f"{cur_dir}/videos/bike.mp3", "Can you tell me what I'm doing in short words. Describe them in a natural style."]
                ],
                inputs=[video_input, audio_display, text_input],
                outputs=[output_text]
            )

    # Add event handler for video changes
    video_input.change(
        fn=lambda video_path: extract_audio_from_video(video_path, audio_path=None),
        inputs=[video_input],
        outputs=[audio_display]
    )

    # Add event handler for video clear/delete
    def clear_outputs(video):
        if video is None:  # Video is cleared/deleted
            return ""
        return gr.skip()  # Keep existing text if video exists

    video_input.change(
        fn=clear_outputs,
        inputs=[video_input],
        outputs=[output_text]
    )

    # Add submit button and its event handler
    submit_btn = gr.Button("Submit")
    submit_btn.click(
        fn=generate_text,
        inputs=[video_input, audio_display, text_input],
        outputs=[output_text]
    )

# Launch the Gradio app
if __name__ == "__main__":
    demo.launch(share=True)