#!/bin/sh

# Name of the application; also used as the directory name inside the dataset repo.
APP_NAME="AiChat"

# Comma-separated path prefixes: WHITELIST limits what gets synced, BLACKLIST excludes paths.
WHITELIST="webui.db"
BLACKLIST=""
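# Illustrative (hypothetical values): WHITELIST="webui.db,uploads" would sync only
# those paths; BLACKLIST="cache" would skip anything under cache. Empty means no filter.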

if [ -z "$HF_TOKEN" ] || [ -z "$DATASET_ID" ]; then
    echo "Missing required environment variable HF_TOKEN or DATASET_ID"
    exit 1
fi

mkdir -p "./data"
mkdir -p "/tmp/${APP_NAME}"

# Write the Python sync helper to /tmp; the quoted 'EOL' delimiter prevents shell expansion.
cat > /tmp/hf_sync.py << 'EOL'
import os
import sys
import hashlib
import shutil
from datetime import datetime
from zoneinfo import ZoneInfo  # requires Python 3.9+

from huggingface_hub import HfApi, CommitOperationAdd, CommitOperationDelete

SHANG_HAI_TZ = ZoneInfo("Asia/Shanghai")


def log_print(*args, prefix="[SyncData]", **kwargs):
    """Print with a Shanghai-time timestamp and a fixed prefix."""
    timestamp = datetime.now(SHANG_HAI_TZ).strftime("%Y-%m-%d %H:%M:%S")
    print(f"[{timestamp}]", prefix, *args, **kwargs)


def calculate_file_hash(file_path):
    """Return the MD5 hash of a file, or None on error."""
    # MD5 is used here only for change detection, not for security.
    try:
        with open(file_path, 'rb') as f:
            return hashlib.md5(f.read()).hexdigest()
    except Exception as e:
        log_print(f"Error calculating hash for {file_path}: {e}")
        return None
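

# The helper above reads each file into memory in one shot. A chunked variant
# (illustrative sketch, not used by this script) would bound memory usage for
# large files:
def calculate_file_hash_chunked(file_path, chunk_size=1 << 20):
    h = hashlib.md5()
    with open(file_path, 'rb') as f:
        # Read fixed-size chunks until the empty-bytes sentinel (EOF).
        for chunk in iter(lambda: f.read(chunk_size), b''):
            h.update(chunk)
    return h.hexdigest()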


def compare_and_sync_directories(source_dir, target_dir, whitelist=None, blacklist=None):
    """
    Compare file hashes between source_dir and target_dir.
    Files that differ are copied from source_dir into target_dir.
    Returns the list of files that need to be uploaded.
    """
    files_to_upload = []

    def should_include_path(path, whitelist, blacklist):
        """Check whether a relative path passes the whitelist/blacklist filters."""
        if whitelist:
            if not any(path.startswith(item.rstrip("/")) for item in whitelist):
                return False
        if blacklist:
            if any(path.startswith(item.rstrip("/")) for item in blacklist):
                return False
        return True
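
    # Illustration: with whitelist=["webui.db"], only paths starting with
    # "webui.db" pass; with blacklist=["cache"], anything starting with
    # "cache" is skipped. The whitelist is applied before the blacklist.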

    def walk_and_filter(root_dir, rel_path="", whitelist=None, blacklist=None):
        """Walk a directory tree, filtering entries through the white/blacklist."""
        full_path = os.path.join(root_dir, rel_path)
        if not os.path.exists(full_path):
            return []

        filtered_files = []
        try:
            entries = os.listdir(full_path)
            for entry in entries:
                entry_rel_path = os.path.join(rel_path, entry)
                entry_full_path = os.path.join(full_path, entry)

                if not should_include_path(entry_rel_path, whitelist, blacklist):
                    continue

                # Recurse into subdirectories; collect files as relative paths.
                if os.path.isdir(entry_full_path):
                    filtered_files.extend(walk_and_filter(root_dir, entry_rel_path, whitelist, blacklist))
                else:
                    filtered_files.append(entry_rel_path)
        except Exception as e:
            log_print(f"Error processing directory {full_path}: {e}")
        return filtered_files

    # Hash every filtered file under the source directory.
    source_files = {}
    if os.path.exists(source_dir):
        filtered_source_files = walk_and_filter(source_dir, whitelist=whitelist, blacklist=blacklist)
        for relative_path in filtered_source_files:
            file_path = os.path.join(source_dir, relative_path)
            file_hash = calculate_file_hash(file_path)
            if file_hash is not None:
                source_files[relative_path] = file_hash

    # Hash every file in the local /tmp mirror.
    target_files = {}
    if os.path.exists(target_dir):
        for root, _, files in os.walk(target_dir):
            for file in files:
                file_path = os.path.join(root, file)
                relative_path = os.path.relpath(file_path, target_dir)
                file_hash = calculate_file_hash(file_path)
                if file_hash is not None:
                    target_files[relative_path] = file_hash

    # Copy new or changed files into the mirror and mark them for upload.
    for relative_path in source_files.keys():
        source_file_path = os.path.join(source_dir, relative_path)
        target_file_path = os.path.join(target_dir, relative_path)
        if relative_path not in target_files or target_files[relative_path] != source_files[relative_path]:
            os.makedirs(os.path.dirname(target_file_path), exist_ok=True)
            shutil.copy2(source_file_path, target_file_path)
            files_to_upload.append(relative_path)

    # Remove mirror files that no longer exist in the (filtered) source.
    for relative_path in target_files.keys():
        if relative_path not in source_files:
            target_file_path = os.path.join(target_dir, relative_path)
            os.remove(target_file_path)

    return files_to_upload
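

# Note: after compare_and_sync_directories() runs, the /tmp mirror matches the
# filtered view of ./data, and the returned list names exactly the files whose
# content changed since the previous pass.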


def upload_files(api, repo_id, local_dir, remote_dir, files_to_upload, operations):
    """Stage upload operations for local files (appended to `operations`)."""
    for relative_path in files_to_upload:
        local_file_path = os.path.join(local_dir, relative_path)
        remote_file_path = os.path.join(remote_dir, relative_path)
        operations.append(CommitOperationAdd(path_in_repo=remote_file_path, path_or_fileobj=local_file_path))


def delete_files(api, repo_id, remote_dir, files_to_delete, operations):
    """Stage delete operations for remote files (appended to `operations`)."""
    for relative_path in files_to_delete:
        remote_file_path = os.path.join(remote_dir, relative_path)
        operations.append(CommitOperationDelete(path_in_repo=remote_file_path))
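

# Both helpers only stage operations; nothing touches the Hub until
# commit_operations() passes the whole batch to api.create_commit(), so an
# upload-plus-delete sync lands as a single commit.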


def commit_operations(api, repo_id, operations, commit_message):
    """Apply all staged operations as a single Hugging Face commit."""
    try:
        if operations:
            api.create_commit(
                repo_id=repo_id,
                repo_type="dataset",
                operations=operations,
                commit_message=commit_message,
            )
            log_print("Remote version updated successfully!")
        else:
            log_print("Already up to date; nothing to commit!")
    except Exception as e:
        log_print(f"Commit failed: {str(e)}")


def download_files(api, repo_id, remote_dir, local_dir):
    """Download files from the remote repository into the local directory."""
    try:
        remote_files = api.list_repo_files(repo_id=repo_id, repo_type="dataset")
        filtered_files = [file for file in remote_files if file.startswith(remote_dir)]

        for remote_file in filtered_files:
            relative_path = os.path.relpath(remote_file, remote_dir)
            local_file_path = os.path.join(local_dir, relative_path)
            os.makedirs(os.path.dirname(local_file_path), exist_ok=True)
            tmp_filepath = api.hf_hub_download(
                repo_id=repo_id,
                filename=remote_file,
                repo_type="dataset",
            )
            if tmp_filepath and os.path.exists(tmp_filepath):
                shutil.copy2(tmp_filepath, local_file_path)
                # Also seed the /tmp mirror used by the sync step; ensure
                # nested directories exist before copying.
                os.makedirs(os.path.dirname(f"/tmp/{remote_file}"), exist_ok=True)
                shutil.copy2(tmp_filepath, f"/tmp/{remote_file}")

            log_print(f"Downloaded {remote_file} -> {local_file_path}")

    except Exception as e:
        log_print(f"Download failed: {str(e)}")


def sync_repository(api, repo_id, remote_dir, whitelist=None, blacklist=None):
    """Synchronize the local data with the remote repository (single pass)."""
    log_print("Starting data sync process...")
    source_dir = "./data"
    target_dir = f"/tmp/{remote_dir}"

    # Refresh the /tmp mirror from ./data and collect changed files.
    files_to_upload = compare_and_sync_directories(source_dir, target_dir, whitelist, blacklist)

    # Anything present remotely but absent from the mirror should be deleted.
    remote_files = api.list_repo_files(repo_id=repo_id, repo_type="dataset")
    local_files = []
    if os.path.exists(target_dir):
        for root, _, files in os.walk(target_dir):
            for file in files:
                file_path = os.path.join(root, file)
                relative_path = os.path.relpath(file_path, target_dir)
                local_files.append(relative_path)
    local_files_set = set(local_files)

    files_to_delete = [
        os.path.relpath(remote_file, remote_dir)
        for remote_file in remote_files
        if remote_file.startswith(remote_dir) and os.path.relpath(remote_file, remote_dir) not in local_files_set
    ]

    # Stage all uploads and deletions, then push them as one commit.
    operations = []
    upload_files(api, repo_id, source_dir, remote_dir, files_to_upload, operations)
    delete_files(api, repo_id, remote_dir, files_to_delete, operations)

    commit_operations(api, repo_id, operations, f"Sync repository: {remote_dir}")

    if files_to_upload:
        log_print(f"Files uploaded: {files_to_upload}")
    if files_to_delete:
        log_print(f"Files deleted: {files_to_delete}")


if __name__ == "__main__":
    # argv: <action> <token> <repo_id> <remote_dir> [whitelist] [blacklist]
    action = sys.argv[1]
    token = sys.argv[2]
    repo_id = sys.argv[3]
    remote_dir = sys.argv[4]
    api = HfApi(token=token)
    source_dir = "./data"

    if action == "sync":
        whitelist = sys.argv[5].split(",") if len(sys.argv) > 5 and sys.argv[5] not in ["", "None"] else None
        blacklist = sys.argv[6].split(",") if len(sys.argv) > 6 and sys.argv[6] not in ["", "None"] else None
        sync_repository(api, repo_id, remote_dir, whitelist, blacklist)
    elif action == "download":
        download_files(api, repo_id, remote_dir, source_dir)
EOL
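
# Helper usage:
#   python3 /tmp/hf_sync.py sync     <HF_TOKEN> <DATASET_ID> <remote_dir> [whitelist] [blacklist]
#   python3 /tmp/hf_sync.py download <HF_TOKEN> <DATASET_ID> <remote_dir>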

# Sync to the dataset repo every SYNC_INTERVAL seconds (default 7200 = 2 hours).
sync_data() {
    SYNC_INTERVAL=${SYNC_INTERVAL:-7200}
    while true; do
        python3 /tmp/hf_sync.py sync "${HF_TOKEN}" "${DATASET_ID}" "${APP_NAME}" "${WHITELIST}" "${BLACKLIST}"
        sleep "${SYNC_INTERVAL}"
    done
}
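
# Restore remote state once at startup, then keep syncing in the background;
# restoring first avoids pushing an empty ./data over the remote copy.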

python3 /tmp/hf_sync.py download "${HF_TOKEN}" "${DATASET_ID}" "${APP_NAME}"

sync_data &