Spaces:
Runtime error
Runtime error
File size: 3,306 Bytes
e749bbe bb448d0 442df1d e749bbe 442df1d bb448d0 e749bbe 442df1d b7be715 bb448d0 442df1d b7be715 442df1d d48ad83 e749bbe bb448d0 442df1d 93f9dc3 442df1d e749bbe d48ad83 e749bbe 5efced7 e749bbe d48ad83 e749bbe d48ad83 e749bbe d48ad83 e749bbe d48ad83 e749bbe 5bffb3b e749bbe d48ad83 e749bbe |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 |
import json
import logging
import os
from pathlib import Path

import gradio as gr
import torch
from huggingface_hub import hf_hub_download
from transformers import GPT2Config, GPT2LMHeadModel, GPT2TokenizerFast
# Load the base GPT-2 model and tokenizer used for similarity scoring.
model = GPT2LMHeadModel.from_pretrained('gpt2')
model.eval()  # inference only — disable dropout
tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
# GPT-2 has no pad token; reuse EOS so fixed-length padding works.
tokenizer.pad_token = tokenizer.eos_token

# Precompute logits for every indexed text description, keyed by UUID.
# index.json is expected to map uuid -> {"text_description": str, ...}.
logits_dict = {}
json_file = 'index.json'
with open(json_file, 'r') as file:
    data = json.load(file)
for key, value in data.items():
    text_description = value['text_description']
    inputs = tokenizer(text_description, return_tensors="pt",
                       padding="max_length", max_length=128, truncation=True)
    # Bug fix: the original forward pass ran with autograd enabled and with
    # labels= (forcing a pointless loss computation); every stored logits
    # tensor kept its whole computation graph alive, leaking memory. Only
    # the logits are needed, so run under no_grad.
    with torch.no_grad():
        outputs = model(**inputs)
    logits_dict[key] = outputs.logits
def search_index(query):
    """Return the UUID of the indexed description most similar to *query*.

    Similarity is the elementwise product (summed) between the query's GPT-2
    logits and the precomputed logits of each description in ``logits_dict``.

    Returns:
        The best-matching UUID, or None when the index is empty.
    """
    inputs = tokenizer(query, return_tensors="pt", padding="max_length",
                       max_length=128, truncation=True)
    # Inference only: no autograd graph needed (and labels= added nothing).
    with torch.no_grad():
        outputs = model(**inputs)
    max_similarity = float('-inf')
    max_similarity_uuid = None
    for uuid, logits in logits_dict.items():
        similarity = (outputs.logits * logits).sum()
        if similarity > max_similarity:
            max_similarity = similarity
            max_similarity_uuid = uuid
    # Bug fix: gradio exposes no ``gr.logger`` attribute — the original line
    # raised AttributeError on every query. Use the stdlib logging module.
    logging.getLogger(__name__).info("Query: %s", query)
    return max_similarity_uuid
def download_video(uuid):
    """Download the binocular video for *uuid* from the 360x dataset repo.

    Videos are cached under ``videos/``; after a download, the least
    recently modified cached files are evicted until total size fits the
    40 GiB budget (the file just downloaded is never evicted).

    Returns:
        str: local filesystem path of the video.
    """
    dataset_name = "quchenyuan/360x_dataset"
    dataset_path = "360_dataset/binocular/"
    video_filename = f"{uuid}.mp4"
    # Ensure the storage directory exists.
    storage_dir = Path("videos")
    storage_dir.mkdir(exist_ok=True)
    storage_limit = 40 * 1024 * 1024 * 1024  # 40 GiB cache budget

    local_path = storage_dir / dataset_path / video_filename
    if not local_path.is_file():
        # Bug fixes vs. the original:
        #  * it called os.path.getsize() on a file that did not exist yet
        #    (FileNotFoundError on every fresh download);
        #  * it ignored the path returned by hf_hub_download and returned a
        #    path that was never written;
        #  * it omitted repo_type="dataset" for a dataset repository.
        downloaded = hf_hub_download(dataset_name,
                                     dataset_path + video_filename,
                                     repo_type="dataset",
                                     local_dir=storage_dir)
        local_path = Path(downloaded)

    def _cache_size():
        # Total bytes of all regular files under the cache directory.
        return sum(f.stat().st_size
                   for f in storage_dir.rglob('*') if f.is_file())

    # Evict oldest files until we are back under budget (original evicted
    # at most one file, which could leave the cache over the limit).
    while _cache_size() > storage_limit:
        candidates = [f for f in storage_dir.rglob('*')
                      if f.is_file() and f != local_path]
        if not candidates:
            break
        min(candidates, key=os.path.getmtime).unlink()

    return str(local_path)
# Gradio interface function
def search_and_show_video(query):
    """Gradio click handler: search the index and show the matched video.

    The click event binds three ``gr.Video`` outputs, so this must return
    three values — the original returned a single path, which Gradio
    rejects at event time. The same path is shown in all three panes.
    """
    uuid = search_index(query)
    video_path = download_video(uuid)
    return video_path, video_path, video_path
if __name__ == "__main__":
    # Declarative Gradio layout: the nesting order of Row/Column context
    # managers determines on-screen placement, so statement order matters.
    with gr.Blocks() as demo:
        with gr.Column():
            with gr.Row():
                # Page title.
                gr.HTML("<h1><i>360+x</i> : A Panoptic Multi-modal Scene Understanding Dataset</h1>")
            with gr.Row():
                # Links to the project website and the paper.
                gr.HTML("<p><a href='https://x360dataset.github.io/'>Official Website</a> <a href='https://arxiv.org/abs/2404.00989'>Paper</a></p>")
            with gr.Row():
                gr.HTML("<h2>Search for a video by entering a query below:</h2>")
            with gr.Row():
                # Free-text query box fed to search_and_show_video.
                search_input = gr.Textbox(label="Query", placeholder="Enter a query to search for a video.")
            with gr.Row():
                # Three side-by-side video panes.
                with gr.Column():
                    video_output_1 = gr.Video()
                with gr.Column():
                    video_output_2 = gr.Video()
                with gr.Column():
                    video_output_3 = gr.Video()
            with gr.Row():
                submit_button = gr.Button(value="Search")
        # Event wiring must happen inside the Blocks context; the handler
        # is expected to produce one value per bound output component.
        submit_button.click(search_and_show_video, search_input,
                            outputs=[video_output_1, video_output_2, video_output_3])
    demo.launch()
|