# Author: ynhe
# Last commit: b27d6ec "Update app.py" (verified)
import json
import os
import random
import shutil
import threading

import gradio as gr
# from datasets import load_dataset
from datasets import config
from huggingface_hub import Repository,HfApi
from huggingface_hub import snapshot_download
# Hugging Face access token, read from the "hf_token" environment variable.
# It authenticates the HfApi client and the feedback uploads.
try:
    hf_token = os.environ['hf_token']
except KeyError as err:
    # Fail fast with an actionable message instead of a bare KeyError.
    raise RuntimeError(
        "Environment variable 'hf_token' is not set; "
        "it must contain a Hugging Face access token."
    ) from err
# Local folder that downloaded sample videos are written into.
local_dir = "VBench_sampled_video"
def print_directory_contents(path, indent=0):
    """Recursively print every entry below ``path``, one per line.

    Each nesting level adds one leading space. If a directory cannot be
    listed because of a ``PermissionError``, a placeholder line is printed
    at that level instead of aborting the walk.
    """
    pad = ' ' * indent
    try:
        entries = os.listdir(path)
    except PermissionError:
        print(pad + "[权限错误,无法访问该目录]")
        return
    for entry in entries:
        print(pad + entry)
        child = os.path.join(path, entry)
        # Descend into sub-directories one level deeper.
        if os.path.isdir(child):
            print_directory_contents(child, indent + 1)
# Fetch the dataset layout: each top-level entry of the sampled-video repo is
# treated as one model name.
os.makedirs(local_dir, exist_ok=True)
# One authenticated client for the default Hub endpoint. (Previously a second
# client with an explicit endpoint was built and immediately discarded.)
hf_api = HfApi(token=hf_token)
repo_id = "Vchitect/VBench_sampled_video"
# Skip git metadata and markdown files (e.g. ".gitattributes", "README.md").
# This is a substring match, so any entry whose path contains ".git" or ".md"
# is excluded.
model_names = [
    entry.path
    for entry in hf_api.list_repo_tree(repo_id, repo_type='dataset')
    if '.git' not in entry.path and '.md' not in entry.path
]
# Map each ability dimension to the bare file names (basenames) of its videos.
# JSON is UTF-8 by specification, so decode explicitly rather than relying on
# the platform's locale encoding.
with open("videos_by_dimension.json", encoding="utf-8") as f:
    dimension = {
        dim: [os.path.basename(path) for path in paths]
        for dim, paths in json.load(f)['videos_by_dimension'].items()
    }
# Ability dimensions offered in the arena; each is a key of ``dimension`` and
# a sub-folder name inside every model's folder in the dataset repo.
types = [
    'appearance_style',
    'color',
    'temporal_style',
    'spatial_relationship',
    'temporal_flickering',
    'scene',
    'multiple_objects',
    'object_class',
    'human_action',
    'overall_consistency',
    'subject_consistency',
]
def get_video_path_local(model_name, type, prompt):
    """Download one sample video from ``repo_id`` and return its local path.

    Args:
        model_name: top-level model folder in the dataset repo.
        type: ability dimension, used as a sub-folder of the model folder.
        prompt: video file name (basename) as listed in ``dimension``.

    Returns:
        The local path returned by ``hf_hub_download``, or ``None`` when the
        file could not be downloaded from any candidate folder.
    """
    # Show-1 and videocrafter-1 keep their videos in an extra sub-folder.
    if 'Show-1' in model_name:
        primary = os.path.join(model_name, type, 'super2')
    elif 'videocrafter-1' in model_name:
        primary = os.path.join(model_name, type, '1024x576')
    else:
        primary = os.path.join(model_name, type)
    # cogvideo videos are looked up as .gif files instead of .mp4.
    if model_name == 'cogvideo':
        prompt = prompt.replace(".mp4", ".gif")
    # Try the dimension-specific folder first, then the model's root folder.
    candidates = (primary, model_name)
    for attempt, subfolder in enumerate(candidates):
        try:
            return hf_api.hf_hub_download(
                repo_id=repo_id,
                filename=prompt,
                subfolder=subfolder,
                repo_type="dataset",
                local_dir=local_dir,
            )
        except Exception as e:
            # Best-effort lookup: any download failure moves on to the next
            # candidate folder rather than crashing the UI handler.
            if attempt + 1 < len(candidates):
                print(f"[PATH]{subfolder}/{prompt} NOT in hf repo, try {candidates[attempt + 1]}", e)
            else:
                print(f"[PATH]{subfolder}/{prompt} NOT in hf repo", e)
    print('error:', model_name, type, prompt)
    return None
def get_random_video():
    """Draw a random dimension, prompt and pair of distinct models, then fetch both videos.

    Returns:
        ``(video_path1, video_path2, model_name_1, model_name_2, type, prompt)``;
        each video path comes from ``get_video_path_local`` and may be ``None``.
    """
    dim_type = types[random.randrange(len(types))]
    prompts = dimension[dim_type]
    picked = prompts[random.randrange(len(prompts))]
    # Entries in ``dimension`` are already basenames; kept as a safeguard.
    prompt = os.path.basename(picked)
    # Two different models are compared on the same dimension and prompt.
    model_name_1, model_name_2 = random.sample(model_names, 2)
    video_path1, video_path2 = (
        get_video_path_local(name, dim_type, prompt)
        for name in (model_name_1, model_name_2)
    )
    return video_path1, video_path2, model_name_1, model_name_2, dim_type, prompt
def update_prompt_options(type, value=None):
    """Build a Gradio update that loads ``type``'s prompts into the prompt dropdown.

    A truthy ``value`` is kept as the selection; otherwise the first prompt is
    selected. If the dimension has no prompts, the selection is ``None``.
    """
    choices = dimension[type]
    if not choices:
        selected = None
    elif value:
        selected = value
    else:
        selected = choices[0]
    return gr.update(choices=choices, value=selected)
def display_videos(type, prompt, model_name_1, model_name_2):
    """Fetch the videos of two user-selected models for one dimension and prompt.

    Returns:
        ``(video_path1, video_path2)`` from ``get_video_path_local``; either
        entry may be ``None`` if its download failed.
    """
    first, second = (
        get_video_path_local(name, type, prompt)
        for name in (model_name_1, model_name_2)
    )
    return first, second
# Dataset repo and file that collect arena votes.
_FEEDBACK_REPO_ID = "Vchitect/VBench_human_annotation"
_FEEDBACK_FILE = "arena_feedback.csv"
# Every vote rewrites the same local copy of the feedback file
# (download -> append -> upload). If several users vote at once (presumably
# handled in parallel worker threads), unsynchronized votes can overwrite each
# other's rows. The lock serializes the cycle, but only within this process.
_feedback_lock = threading.Lock()


def _record_user_feedback(model_name1, model_name2, type, prompt, label):
    """Append one vote to the shared feedback file on the Hub and hide the vote buttons.

    The feedback file is downloaded to ``./``, one tab-separated row
    ``model_name1, model_name2, type, prompt, label`` is appended, and the
    whole file is uploaded back as a new commit.

    Args:
        label: 0 means model A better, 1 means model B better, -1 means tie.

    Returns:
        Three ``gr.update(visible=False)`` updates for the three vote buttons.
    """
    with _feedback_lock:
        hf_api.hf_hub_download(
            repo_id=_FEEDBACK_REPO_ID,
            filename=_FEEDBACK_FILE,
            repo_type="dataset",
            local_dir='./',
        )
        with open(_FEEDBACK_FILE, 'a', encoding="utf-8") as f:
            f.write(f"{model_name1}\t{model_name2}\t{type}\t{prompt}\t{label}\n")
        hf_api.upload_file(
            path_or_fileobj=_FEEDBACK_FILE,
            path_in_repo=_FEEDBACK_FILE,
            repo_id=_FEEDBACK_REPO_ID,
            token=hf_token,
            repo_type="dataset",
            commit_message="[From VBench Arena] user feedback",
        )
    return gr.update(visible=False), gr.update(visible=False), gr.update(visible=False)


def record_user_feedback_a(model_name1, model_name2, type, prompt):
    """Record a vote that model A (first video) is better (label 0)."""
    return _record_user_feedback(model_name1, model_name2, type, prompt, 0)


def record_user_feedback_b(model_name1, model_name2, type, prompt):
    """Record a vote that model B (second video) is better (label 1)."""
    return _record_user_feedback(model_name1, model_name2, type, prompt, 1)


def record_user_feedback_tie(model_name1, model_name2, type, prompt):
    """Record a tie vote (label -1)."""
    return _record_user_feedback(model_name1, model_name2, type, prompt, -1)
def show_feedback_button():
    """Return three updates that make the three vote buttons visible."""
    return tuple(gr.update(visible=True) for _ in range(3))
def _hide_feedback_buttons():
    """Return three updates that hide the three vote buttons."""
    return tuple(gr.update(visible=False) for _ in range(3))


with gr.Blocks() as interface:
    gr.Markdown("# VBench Video Arena")
    gr.Markdown("""
**VBench Video Arena: Watch AI-Generated Videos Instantly** (powered by [VBench](https://github.com/Vchitect/VBench) and [VBench Leaderboard](https://huggingface.co/spaces/Vchitect/VBench_Leaderboard))\n
- **Random 2 Videos**: Randomly selects two models to compare on the same ability dimension and text prompt.\n
- **Play Selection** Allows users to choose a model, dimension, and text prompt from drop-down menus and view the corresponding videos. """)
    type_output = gr.Dropdown(label="Ability Dimension", choices=types, value=types[0])
    prompt_output = gr.Dropdown(label="Text Prompt", choices=dimension[types[0]], value=dimension[types[0]][0])
    # Prompt drawn by the last "Random 2 Videos" click; votes are recorded against it.
    prompt_placeholder = gr.State()
    with gr.Row():
        random_button = gr.Button("🎲 Random 2 Videos")
        display_button = gr.Button("🎇 Play Selection")
    with gr.Row():
        with gr.Column():
            model_name_1_output = gr.Dropdown(label="Model Name 1", choices=model_names, value=model_names[0])
            video_output_1 = gr.Video(label="Video 1")
        with gr.Column():
            model_name_2_output = gr.Dropdown(label="Model Name 2", choices=model_names, value=model_names[1])
            video_output_2 = gr.Video(label="Video 2")
    with gr.Row():
        feed0 = gr.Button("👈 Model A is better", visible=False)
        feedt = gr.Button("😫 It's hard to say", visible=False)
        feed1 = gr.Button("👉 Model B is better", visible=False)
    feedback_buttons = [feed0, feedt, feed1]

    type_output.change(fn=update_prompt_options, inputs=[type_output], outputs=[prompt_output])
    random_button.click(
        fn=get_random_video,
        outputs=[video_output_1, video_output_2, model_name_1_output, model_name_2_output, type_output, prompt_placeholder],
    ).then(
        # prompt_placeholder always holds the freshly drawn (non-empty) prompt,
        # so this single update both refreshes the choices and selects it.
        fn=update_prompt_options,
        inputs=[type_output, prompt_placeholder],
        outputs=[prompt_output],
    ).then(
        fn=show_feedback_button,
        outputs=feedback_buttons,
    )
    display_button.click(
        fn=display_videos,
        inputs=[type_output, prompt_output, model_name_1_output, model_name_2_output],
        outputs=[video_output_1, video_output_2],
    ).then(
        # Votes use prompt_placeholder (the last random draw), so after a manual
        # selection they would be mislabeled; hide the vote buttons instead.
        fn=_hide_feedback_buttons,
        outputs=feedback_buttons,
    )
    for button, handler in (
        (feed0, record_user_feedback_a),
        (feed1, record_user_feedback_b),
        (feedt, record_user_feedback_tie),
    ):
        button.click(
            fn=handler,
            inputs=[model_name_1_output, model_name_2_output, type_output, prompt_placeholder],
            outputs=feedback_buttons,
        )
interface.launch()