ynhe's picture
Update app.py
b27f379 verified
raw
history blame
3.48 kB
import json
import os
import random
import shutil

# Overwrite the installed huggingface_hub/repository.py with the copy shipped
# next to this app (presumably a patched version). This runs before the
# gradio / huggingface_hub imports below so the replaced file is the one that
# gets loaded (gradio presumably imports huggingface_hub too).
# NOTE(review): the site-packages path hard-codes Python 3.10, and the move
# fails if ./repository.py is no longer present (e.g. on a second run).
shutil.move("repository.py", "/usr/local/lib/python3.10/site-packages/huggingface_hub/repository.py")

import gradio as gr
from huggingface_hub import HfApi, Repository, snapshot_download
# from datasets import load_dataset
from datasets import config
# Hugging Face access token, read from the `hf_token` environment variable.
# Fail fast with an actionable message instead of a bare KeyError('hf_token').
hf_token = os.environ.get('hf_token')
if not hf_token:
    raise KeyError("The 'hf_token' environment variable must be set to a Hugging Face access token.")
# Local folder that the dataset is downloaded into.
local_dir = "VBench_sampled_video"
def print_directory_contents(path, indent=0):
    """Recursively print every entry under *path* as an indented tree.

    Each nesting level adds one leading space. Entries are printed in sorted
    order so the output is deterministic. Symbolic links to directories are
    printed but not descended into; following them could recurse without
    bound when a link points back at one of its ancestors.

    Args:
        path: Directory whose contents are printed.
        indent: Current nesting depth (number of leading spaces). Callers
            normally leave this at 0.
    """
    prefix = ' ' * indent
    try:
        entries = sorted(os.listdir(path))
    except PermissionError:
        # Unreadable directory: report it in place; the caller keeps walking
        # the remaining siblings.
        print(prefix + "[权限错误,无法访问该目录]")
        return
    for item in entries:
        item_path = os.path.join(path, item)
        print(prefix + item)
        # Recurse only into real directories, never through a symlink.
        if os.path.isdir(item_path) and not os.path.islink(item_path):
            print_directory_contents(item_path, indent + 1)
# Download the dataset repository into local_dir.
os.makedirs(local_dir, exist_ok=True)
hf_api = HfApi(token=hf_token)
repo_id = "Vchitect/VBench_sampled_video"
# Log the repository's file list before downloading.
dataset_files = hf_api.list_repo_files(repo_id=repo_id, token=hf_token, repo_type='dataset')
for file in dataset_files:
    print(file)
# `local_dir=` (unlike `cache_dir=`, which uses the hashed cache layout) places
# files at their repo-relative paths directly under local_dir, which is what
# the os.listdir() call below relies on.
snapshot_download(repo_id=repo_id, repo_type='dataset', token=hf_token, local_dir=local_dir)
# Top-level sub-directories of the download are used as model names. Hidden
# entries (e.g. a `.cache` folder huggingface_hub may create) and plain files
# such as `.gitattributes` are skipped so they are never picked as a model.
model_names = sorted(
    entry for entry in os.listdir(local_dir)
    if not entry.startswith('.') and os.path.isdir(os.path.join(local_dir, entry))
)
# Per-dimension lists of video entries; indexed as dimension[type][i] when
# sampling a video.
with open("videos_by_dimension.json", encoding="utf-8") as f:
    dimension = json.load(f)['videos_by_dimension']
# Evaluation dimensions a random video is drawn from. Each name is used as a
# key into `dimension` and as a sub-folder name under a model directory.
types = [
    'appearance_style',
    'color',
    'temporal_style',
    'spatial_relationship',
    'temporal_flickering',
    'scene',
    'multiple_objects',
    'object_class',
    'human_action',
    'overall_consistency',
    'subject_consistency',
]
def get_random_video(*, max_attempts=10, root=None, models=None,
                     videos_by_type=None, video_types=None):
    """Pick a random sampled video and return its file path.

    Each draw picks a dimension type from ``video_types``, then a prompt entry
    from ``videos_by_type[type]``, then a model from ``models``, all uniformly
    at random. Two on-disk layouts are tried under ``root``:
    ``<root>/<model>/<type>/<prompt>`` first, then ``<root>/<model>/<prompt>``.
    A model may not have a given video, so a miss triggers a fresh draw, up to
    ``max_attempts`` draws in total.

    Keyword Args:
        max_attempts: Maximum number of random draws. Values below 1 are
            treated as 1.
        root: Directory holding one sub-folder per model. Defaults to the
            module-level ``local_dir``. Paths were previously built relative
            to the working directory, so they never pointed into it.
        models: Candidate model folder names. Defaults to ``model_names``.
        videos_by_type: Mapping from dimension type to a list of prompt
            entries. Defaults to ``dimension``.
        video_types: Dimension types to sample from. Defaults to ``types``.

    Returns:
        The path of an existing video. If every draw misses, an error is
        printed and the last (non-existent) flat-layout candidate is returned,
        as before.
    """
    # Module-level data is resolved at call time, and only when no override is given.
    root = local_dir if root is None else root
    models = model_names if models is None else models
    videos_by_type = dimension if videos_by_type is None else videos_by_type
    video_types = types if video_types is None else video_types

    video_path = None
    for _ in range(max(1, max_attempts)):
        video_type = random.choice(video_types)
        prompt = random.choice(videos_by_type[video_type])
        model_name = random.choice(models)
        # Prefer the per-dimension sub-folder, then fall back to a flat layout.
        for video_path in (os.path.join(root, model_name, video_type, prompt),
                           os.path.join(root, model_name, prompt)):
            if os.path.exists(video_path):
                print(video_path)
                return video_path
    print('error:', video_path)
    return video_path
# Gradio callback (the interface has no input components).
def display_video():
    """Return the path of a randomly chosen video for the Gradio output."""
    return get_random_video()
# Single-output Gradio UI ("random video showcase"). There are no input
# components, so display_video is invoked with no arguments.
video_output = gr.Video(label="随机视频展示")
interface = gr.Interface(
    fn=display_video,
    inputs=[],
    outputs=video_output,
    title="随机视频展示",
    description="从 Vchitect/VBench_sampled_video 数据集中随机展示一个视频。",
)

if __name__ == "__main__":
    interface.launch()