Spaces:
Runtime error
Runtime error
File size: 3,180 Bytes
e749bbe 442df1d e749bbe 442df1d e749bbe 442df1d d48ad83 e749bbe 442df1d e749bbe d48ad83 e749bbe 5efced7 e749bbe d48ad83 e749bbe d48ad83 e749bbe d48ad83 e749bbe d48ad83 e749bbe 5bffb3b e749bbe d48ad83 e749bbe |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 |
import json
import os
from pathlib import Path

import gradio as gr
import torch
from huggingface_hub import hf_hub_download
from transformers import GPT2Config, GPT2LMHeadModel, GPT2Tokenizer
# Load the pretrained GPT-2 language model and its tokenizer once at import
# time; both are reused for building the index and for every search query.
model = GPT2LMHeadModel.from_pretrained('gpt2')
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")

# Maps each video uuid to a 1-D tensor of length vocab_size: the GPT-2 logits
# of the video's text description averaged over its token positions.
# Averaging gives every entry the same shape regardless of description
# length, so entries can be compared with a query of any length.
logits_dict = {}

# index.json is read as a list of objects carrying 'uuid' and
# 'text_description' keys.
json_file = 'index.json'
with open(json_file, 'r', encoding='utf-8') as file:
    data = json.load(file)

# Inference only: without no_grad every stored tensor would keep its whole
# autograd graph alive.
with torch.no_grad():
    for item in data:
        uuid = item['uuid']
        text_description = item['text_description']
        if not text_description:
            # An empty description tokenizes to zero tokens, which the model
            # cannot process; leave such items out of the index.
            continue
        # No padding: a single sequence needs none, and the GPT-2 tokenizer
        # has no pad token, so requesting padding raises a ValueError.
        inputs = tokenizer(text_description, return_tensors="pt", truncation=True)
        outputs = model(**inputs)
        # (1, seq_len, vocab) -> (vocab,)
        logits = outputs.logits.mean(dim=1).squeeze(0)
        logits_dict[uuid] = logits
def search_index(query):
    """Return the uuid of the indexed description most similar to *query*.

    The query is run through GPT-2 and its logits are averaged over token
    positions; each entry of ``logits_dict`` is reduced the same way, and the
    entry with the largest dot product against the query wins.

    Args:
        query: Free-text search string.

    Returns:
        The best-matching uuid, or ``None`` when the query is empty/blank or
        the index holds no entries.
    """
    # A blank query tokenizes to zero tokens, which the model cannot process.
    if not query or not query.strip():
        return None

    def _pooled(tensor):
        # Collapse every leading dimension into one and average over it, so
        # both raw (1, seq_len, vocab) logits and already-averaged (vocab,)
        # vectors end up as a single (vocab,) vector.
        return tensor.reshape(-1, tensor.shape[-1]).mean(dim=0)

    max_similarity = float('-inf')
    max_similarity_uuid = None
    with torch.no_grad():
        # No padding: a single sequence needs none, and the GPT-2 tokenizer
        # has no pad token, so requesting padding raises a ValueError.
        inputs = tokenizer(query, return_tensors="pt", truncation=True)
        query_vector = _pooled(model(**inputs).logits)

        for uuid, logits in logits_dict.items():
            # Comparing pooled vectors keeps the product well-defined even
            # when query and description have different token counts.
            similarity = torch.dot(query_vector, _pooled(logits)).item()
            if similarity > max_similarity:
                max_similarity = similarity
                max_similarity_uuid = uuid
    return max_similarity_uuid
def download_video(uuid):
    """Fetch the binocular video for *uuid* into the local ``videos`` folder.

    The file is downloaded from the ``quchenyuan/360x_dataset`` dataset
    repository on the Hugging Face Hub, keeping the repository's relative
    path below ``videos/``. A video that is already present locally is
    reused without contacting the Hub. After a download, the oldest other
    ``.mp4`` files are removed until the folder is back under the storage
    limit.

    Args:
        uuid: Identifier of the video; the file name is ``<uuid>.mp4``.

    Returns:
        Path of the local video file, as a string.
    """
    dataset_name = "quchenyuan/360x_dataset"
    dataset_path = "360_dataset/binocular/"
    video_filename = f"{uuid}.mp4"
    # Make sure the storage directory exists.
    storage_dir = Path("videos")
    storage_dir.mkdir(exist_ok=True)
    storage_limit = 40 * 1024 * 1024 * 1024  # 40 GiB

    # Reuse a previous download; refreshing its mtime makes the eviction
    # below drop the least recently requested videos first.
    local_path = storage_dir / dataset_path / video_filename
    if local_path.is_file():
        local_path.touch()
        return str(local_path)

    # repo_type must be "dataset" (the default looks for a model repository);
    # local_dir places the file under storage_dir so the returned path exists
    # and counts towards the storage limit.
    downloaded_file_path = hf_hub_download(
        repo_id=dataset_name,
        filename=dataset_path + video_filename,
        repo_type="dataset",
        local_dir=storage_dir,
    )
    new_video = Path(downloaded_file_path)

    # The size of a video is only known once it is on disk, so the limit is
    # enforced after the download; the folder may exceed it by one video in
    # the meantime. The file just downloaded is never evicted.
    older_videos = sorted(
        (
            f for f in storage_dir.rglob('*.mp4')
            if f.is_file() and not f.samefile(new_video)
        ),
        key=os.path.getmtime,
    )
    current_storage = new_video.stat().st_size + sum(
        f.stat().st_size for f in older_videos
    )
    for oldest_file in older_videos:
        if current_storage <= storage_limit:
            break
        current_storage -= oldest_file.stat().st_size
        oldest_file.unlink()

    return str(new_video)
# Gradio interface function
def search_and_show_video(query):
    """Find the video that best matches *query* and return its local path.

    Args:
        query: Free-text search string entered by the user.

    Returns:
        Path of the downloaded video file, or ``None`` when the query is
        empty/blank or the search yields no uuid.
    """
    # Nothing to search for: skip the model call entirely.
    if not query or not query.strip():
        return None
    uuid = search_index(query)
    # Without a match there is no file to download; requesting "None.mp4"
    # would only fail.
    if uuid is None:
        return None
    video_path = download_video(uuid)
    return video_path
if __name__ == "__main__":
    def _fill_video_players(query):
        """Adapt the single search result to the three video players.

        The click event below is wired to three outputs, so its handler must
        return three values; a single return value makes Gradio reject the
        event for having too few outputs. The result goes to the first
        player and the other two are cleared.
        """
        # NOTE(review): players 2 and 3 are presumably meant for further
        # results or camera views - confirm the intended content.
        return search_and_show_video(query), None, None

    # Page layout: title, links, prompt, query box, three video players and
    # the search button, stacked in a single column.
    with gr.Blocks() as demo:
        with gr.Column():
            with gr.Row():
                gr.HTML("<h1><i>360+x</i> : A Panoptic Multi-modal Scene Understanding Dataset</h1>")
            with gr.Row():
                gr.HTML("<p><a href='https://x360dataset.github.io/'>Official Website</a> <a href='https://arxiv.org/abs/2404.00989'>Paper</a></p>")
            with gr.Row():
                gr.HTML("<h2>Search for a video by entering a query below:</h2>")
            with gr.Row():
                search_input = gr.Textbox(label="Query", placeholder="Enter a query to search for a video.")
            with gr.Row():
                with gr.Column():
                    video_output_1 = gr.Video()
                with gr.Column():
                    video_output_2 = gr.Video()
                with gr.Column():
                    video_output_3 = gr.Video()
            with gr.Row():
                submit_button = gr.Button(value="Search")
        # Event listeners have to be registered inside the Blocks context.
        submit_button.click(_fill_video_players, search_input,
                            outputs=[video_output_1, video_output_2, video_output_3])
    demo.launch()
|