# AiChat / sync_data.sh
# Source: leafmoes's Hugging Face Space, commit 57b5848 (verified), "add timestamp".
#!/bin/sh
# Application name; also used as the sub-directory inside the dataset repo
# and as the /tmp staging-mirror directory name.
APP_NAME="AiChat"
# Whitelist/blacklist path prefixes (comma separated). The whitelist filter
# is applied first; the blacklist then filters whatever the whitelist kept.
WHITELIST="webui.db"
BLACKLIST=""
# Both HF_TOKEN and DATASET_ID must be supplied via the environment.
if [ -z "$HF_TOKEN" ] || [ -z "$DATASET_ID" ]; then
echo "缺少必要的环境变量 HF_TOKEN 或 DATASET_ID"
exit 1
fi
# Local data directory plus the /tmp mirror used for change detection.
mkdir -p "./data"
mkdir -p "/tmp/${APP_NAME}"
cat > /tmp/hf_sync.py << 'EOL'
import os
import sys
import hashlib
import shutil
from datetime import datetime
from zoneinfo import ZoneInfo
from huggingface_hub import HfApi, CommitOperationAdd, CommitOperationDelete
# All log timestamps are rendered in Shanghai local time.
SHANG_HAI_TZ = ZoneInfo("Asia/Shanghai")
def log_print(*args, prefix="[SyncData]", **kwargs):
    """Print *args behind a Shanghai-local timestamp and a log prefix."""
    stamp = datetime.now(SHANG_HAI_TZ).strftime("%Y-%m-%d %H:%M:%S")
    print("[" + stamp + "]", prefix, *args, **kwargs)
def calculate_file_hash(file_path):
    """
    Return the MD5 hex digest of the file at *file_path*, or None on error.

    Fix: hash in fixed-size chunks instead of f.read() — the synced file is
    a SQLite database that can grow large, and reading it whole would hold
    the entire file in memory on every sync pass.
    """
    try:
        digest = hashlib.md5()
        with open(file_path, 'rb') as f:
            # 1 MiB chunks; iter() stops at the b'' sentinel on EOF.
            for chunk in iter(lambda: f.read(1 << 20), b''):
                digest.update(chunk)
        return digest.hexdigest()
    except Exception as e:
        log_print(f"Error calculating hash for {file_path}: {e}")
        return None
def compare_and_sync_directories(source_dir, target_dir, whitelist=None, blacklist=None):
    """
    Compare file MD5 hashes between source_dir and target_dir and mirror
    source_dir into target_dir.

    Files whose hash differs (or that are missing from target_dir) are
    copied over; files present only in target_dir are deleted, so the
    target becomes an exact filtered mirror of the source.

    :param source_dir: local data directory (the authoritative copy)
    :param target_dir: staging mirror (e.g. /tmp/<app>) used for change detection
    :param whitelist: path prefixes to include (applied first), or None
    :param blacklist: path prefixes to exclude (applied second), or None
    :return: list of relative paths that changed and must be uploaded
    """
    files_to_upload = []

    def should_include_path(path, whitelist, blacklist, is_dir=False):
        """
        Apply whitelist-then-blacklist prefix filters to a relative path.

        Bug fix: a directory is also kept when it is an ancestor of a
        whitelisted sub-path (e.g. whitelist "sub/file.txt" must not prune
        directory "sub", which the plain prefix test used to do). Plain
        prefix semantics for files are preserved on purpose so companions
        like "webui.db-wal" still match the "webui.db" whitelist entry.
        """
        if whitelist:
            keep = False
            for item in whitelist:
                item = item.rstrip("/")
                if path.startswith(item) or (is_dir and item.startswith(path + "/")):
                    keep = True
                    break
            if not keep:
                return False
        if blacklist:
            if any(path.startswith(item.rstrip("/")) for item in blacklist):
                return False
        return True

    def walk_and_filter(root_dir, rel_path="", whitelist=None, blacklist=None):
        """
        Recursively list files under root_dir/rel_path that pass the filters.
        """
        full_path = os.path.join(root_dir, rel_path)
        if not os.path.exists(full_path):
            return []
        filtered_files = []
        try:
            for entry in os.listdir(full_path):
                entry_rel_path = os.path.join(rel_path, entry)
                entry_full_path = os.path.join(full_path, entry)
                is_dir = os.path.isdir(entry_full_path)
                if not should_include_path(entry_rel_path, whitelist, blacklist, is_dir=is_dir):
                    continue
                if is_dir:
                    filtered_files.extend(walk_and_filter(root_dir, entry_rel_path, whitelist, blacklist))
                else:
                    filtered_files.append(entry_rel_path)
        except Exception as e:
            log_print(f"Error processing directory {full_path}: {e}")
        return filtered_files

    # Hash every (filtered) source file.
    source_files = {}
    if os.path.exists(source_dir):
        for relative_path in walk_and_filter(source_dir, whitelist=whitelist, blacklist=blacklist):
            file_hash = calculate_file_hash(os.path.join(source_dir, relative_path))
            if file_hash is not None:
                source_files[relative_path] = file_hash

    # Hash every file currently in the mirror. No filtering here: stale
    # files must be discovered so they can be removed below.
    target_files = {}
    if os.path.exists(target_dir):
        for root, _, files in os.walk(target_dir):
            for file in files:
                file_path = os.path.join(root, file)
                relative_path = os.path.relpath(file_path, target_dir)
                file_hash = calculate_file_hash(file_path)
                if file_hash is not None:
                    target_files[relative_path] = file_hash

    # Copy new/changed files into the mirror and record them for upload.
    for relative_path, source_hash in source_files.items():
        if target_files.get(relative_path) != source_hash:
            target_file_path = os.path.join(target_dir, relative_path)
            os.makedirs(os.path.dirname(target_file_path), exist_ok=True)
            shutil.copy2(os.path.join(source_dir, relative_path), target_file_path)
            files_to_upload.append(relative_path)

    # Remove mirror files that no longer exist in the (filtered) source.
    for relative_path in target_files:
        if relative_path not in source_files:
            os.remove(os.path.join(target_dir, relative_path))
    return files_to_upload
def upload_files(api, repo_id, local_dir, remote_dir, files_to_upload, operations):
    """
    Queue an add-operation for every changed local file.

    api/repo_id are unused here but kept so the operation builders share a
    uniform signature; the actual commit happens in commit_operations().
    """
    operations.extend(
        CommitOperationAdd(
            path_in_repo=os.path.join(remote_dir, relative_path),
            path_or_fileobj=os.path.join(local_dir, relative_path),
        )
        for relative_path in files_to_upload
    )
def delete_files(api, repo_id, remote_dir, files_to_delete, operations):
    """
    Queue a delete-operation for every remote file that vanished locally.
    """
    operations.extend(
        CommitOperationDelete(path_in_repo=os.path.join(remote_dir, relative_path))
        for relative_path in files_to_delete
    )
def commit_operations(api, repo_id, operations, commit_message):
    """
    Push all queued add/delete operations as one dataset commit.

    Never raises: failures are reported via log_print so the periodic
    sync loop keeps running.
    """
    if not operations:
        log_print("当前版本已为最新,无需更新!")
        return
    try:
        api.create_commit(
            repo_id=repo_id,
            repo_type="dataset",
            operations=operations,
            commit_message=commit_message,
        )
    except Exception as e:
        log_print(f"更新提交失败: {str(e)}")
    else:
        log_print("已成功更新云端版本!")
def download_files(api, repo_id, remote_dir, local_dir):
    """
    Download every dataset file under remote_dir into local_dir, keeping a
    copy in the /tmp staging mirror so the next sync pass sees the same
    baseline and does not immediately re-upload unchanged files.

    Fixes vs. the previous version:
      * the /tmp mirror directory is created before copying, so nested
        remote paths (e.g. "<app>/sub/x") no longer fail;
      * each file downloads inside its own try block, so one broken file
        no longer aborts the whole restore.
    """
    try:
        remote_files = api.list_repo_files(repo_id=repo_id, repo_type="dataset")
    except Exception as e:
        log_print(f"下载失败: {str(e)}")
        return
    for remote_file in remote_files:
        if not remote_file.startswith(remote_dir):
            continue
        try:
            relative_path = os.path.relpath(remote_file, remote_dir)
            local_file_path = os.path.join(local_dir, relative_path)
            os.makedirs(os.path.dirname(local_file_path), exist_ok=True)
            # hf_hub_download returns a path inside the local HF cache.
            tmp_filepath = api.hf_hub_download(
                repo_id=repo_id,
                filename=remote_file,
                repo_type="dataset",
            )
            if tmp_filepath and os.path.exists(tmp_filepath):
                shutil.copy2(tmp_filepath, local_file_path)
                mirror_path = f"/tmp/{remote_file}"
                os.makedirs(os.path.dirname(mirror_path), exist_ok=True)
                shutil.copy2(tmp_filepath, mirror_path)
            log_print(f"已下载{remote_file} -> {local_file_path}")
        except Exception as e:
            log_print(f"下载失败: {str(e)}")
def sync_repository(api, repo_id, remote_dir, whitelist=None, blacklist=None):
    """
    Run one full sync pass: mirror ./data into /tmp/<remote_dir>, then
    commit uploads for changed files and deletions for vanished ones.
    """
    log_print("开始数据同步进程...")
    source_dir = "./data"
    target_dir = f"/tmp/{remote_dir}"

    # Mirror local data into the staging dir; the return value is the set
    # of files whose content changed since the previous pass.
    files_to_upload = compare_and_sync_directories(source_dir, target_dir, whitelist, blacklist)

    # Remote files under remote_dir that are absent from the mirror were
    # removed locally and must be deleted remotely as well.
    remote_files = api.list_repo_files(repo_id=repo_id, repo_type="dataset")
    local_files_set = set()
    if os.path.exists(target_dir):
        for root, _, names in os.walk(target_dir):
            local_files_set.update(
                os.path.relpath(os.path.join(root, name), target_dir) for name in names
            )
    files_to_delete = [
        os.path.relpath(remote_file, remote_dir)
        for remote_file in remote_files
        if remote_file.startswith(remote_dir)
        and os.path.relpath(remote_file, remote_dir) not in local_files_set
    ]

    # Build one combined commit from the add and delete operations.
    operations = []
    upload_files(api, repo_id, source_dir, remote_dir, files_to_upload, operations)
    delete_files(api, repo_id, remote_dir, files_to_delete, operations)
    commit_operations(api, repo_id, operations, f"Sync repository: {remote_dir}")

    if files_to_upload:
        log_print(f"文件已上传: {files_to_upload}")
    if files_to_delete:
        log_print(f"文件已删除: {files_to_delete}")
if __name__ == "__main__":
    # CLI: hf_sync.py <action> <token> <repo_id> <remote_dir> [whitelist] [blacklist]
    #   "sync":     one upload pass (args 5/6 are comma-separated prefix filters)
    #   "download": restore the remote files into ./data
    # Fix: validate arg count up front instead of dying with an IndexError,
    # and report unknown actions instead of silently doing nothing.
    if len(sys.argv) < 5:
        log_print("Usage: hf_sync.py <sync|download> <token> <repo_id> <remote_dir> [whitelist] [blacklist]")
        sys.exit(1)
    action = sys.argv[1]
    token = sys.argv[2]
    repo_id = sys.argv[3]
    remote_dir = sys.argv[4]
    api = HfApi(token=token)
    source_dir = "./data"
    if action == "sync":
        # An empty string or the literal "None" means "no filter".
        whitelist = sys.argv[5].split(",") if len(sys.argv) > 5 and sys.argv[5] not in ["", "None"] else None
        blacklist = sys.argv[6].split(",") if len(sys.argv) > 6 and sys.argv[6] not in ["", "None"] else None
        sync_repository(api, repo_id, remote_dir, whitelist, blacklist)
    elif action == "download":
        download_files(api, repo_id, remote_dir, source_dir)
    else:
        log_print(f"Unknown action: {action}")
        sys.exit(1)
EOL
sync_data() {
    # Periodic upload loop: push ./data to the dataset repo forever.
    # Interval defaults to 7200 seconds (2 hours); override via SYNC_INTERVAL.
    : "${SYNC_INTERVAL:=7200}"
    while :; do
        python3 /tmp/hf_sync.py sync "${HF_TOKEN}" "${DATASET_ID}" "${APP_NAME}" "${WHITELIST}" "${BLACKLIST}"
        sleep "${SYNC_INTERVAL}"
    done
}
# Restore the latest data from the dataset repo before the app starts,
# then launch the periodic upload loop in the background.
python3 /tmp/hf_sync.py download "${HF_TOKEN}" "${DATASET_ID}" "${APP_NAME}"
sync_data &