#!/bin/sh
# Sync script: mirrors ./data to a Hugging Face dataset repo and back.
# Required env vars: HF_TOKEN (write token), DATASET_ID (user/dataset).
# Optional env var:  SYNC_INTERVAL (seconds between uploads, default 7200).

set -u  # treat unset variables as errors (all reads below use ${VAR:-} where optional)

# Application name; doubles as the remote directory inside the dataset repo.
APP_NAME="AiChat"

# White/black list configuration (comma separated). Whitelist filtering is
# applied first, then the blacklist is applied to the whitelisted result.
WHITELIST="webui.db"
BLACKLIST=""

if [ -z "${HF_TOKEN:-}" ] || [ -z "${DATASET_ID:-}" ]; then
  echo "缺少必要的环境变量 HF_TOKEN 或 DATASET_ID" >&2
  exit 1
fi

mkdir -p "./data"
mkdir -p "/tmp/${APP_NAME}"

# Quoted delimiter ('EOL'): the Python source below is written out literally,
# with no shell expansion.
cat > /tmp/hf_sync.py << 'EOL'
import os
import sys
import hashlib
import shutil
from datetime import datetime
from zoneinfo import ZoneInfo
from huggingface_hub import HfApi, CommitOperationAdd, CommitOperationDelete

SHANG_HAI_TZ = ZoneInfo("Asia/Shanghai")


def log_print(*args, prefix="[SyncData]", **kwargs):
    """Print a log line prefixed with an Asia/Shanghai timestamp."""
    timestamp = datetime.now(SHANG_HAI_TZ).strftime("%Y-%m-%d %H:%M:%S")
    print(f"[{timestamp}]", prefix, *args, **kwargs)


def calculate_file_hash(file_path):
    """Return the MD5 hex digest of a file, or None on any read error.

    MD5 is used purely for change detection, not security.
    """
    try:
        with open(file_path, 'rb') as f:
            return hashlib.md5(f.read()).hexdigest()
    except Exception as e:
        log_print(f"Error calculating hash for {file_path}: {e}")
        return None


def compare_and_sync_directories(source_dir, target_dir, whitelist=None, blacklist=None):
    """Compare file hashes between source_dir and target_dir.

    Files that differ (or are new) are copied from source_dir into
    target_dir; files present only in target_dir are removed so that
    target_dir mirrors the filtered view of source_dir.

    Returns the list of relative paths that changed and need uploading.
    """
    files_to_upload = []

    def should_include_path(path, whitelist, blacklist):
        """Apply whitelist (if any) then blacklist (if any) to a relative path."""
        if whitelist:
            if not any(path.startswith(item.rstrip("/")) for item in whitelist):
                return False
        if blacklist:
            if any(path.startswith(item.rstrip("/")) for item in blacklist):
                return False
        return True

    def walk_and_filter(root_dir, rel_path="", whitelist=None, blacklist=None):
        """Recursively list files under root_dir honoring the filters.

        Filtered-out directories are pruned (their contents are never visited).
        """
        full_path = os.path.join(root_dir, rel_path)
        if not os.path.exists(full_path):
            return []
        filtered_files = []
        try:
            entries = os.listdir(full_path)
            for entry in entries:
                entry_rel_path = os.path.join(rel_path, entry)
                entry_full_path = os.path.join(full_path, entry)
                if not should_include_path(entry_rel_path, whitelist, blacklist):
                    continue
                if os.path.isdir(entry_full_path):
                    filtered_files.extend(
                        walk_and_filter(root_dir, entry_rel_path, whitelist, blacklist))
                else:
                    filtered_files.append(entry_rel_path)
        except Exception as e:
            log_print(f"Error processing directory {full_path}: {e}")
        return filtered_files

    # Hash every (filtered) file on the source side.
    source_files = {}
    if os.path.exists(source_dir):
        filtered_source_files = walk_and_filter(
            source_dir, whitelist=whitelist, blacklist=blacklist)
        for relative_path in filtered_source_files:
            file_path = os.path.join(source_dir, relative_path)
            file_hash = calculate_file_hash(file_path)
            if file_hash is not None:
                source_files[relative_path] = file_hash

    # Hash every file on the target side (no filtering: the target mirror
    # must shed files that fell out of the whitelist).
    target_files = {}
    if os.path.exists(target_dir):
        for root, _, files in os.walk(target_dir):
            for file in files:
                file_path = os.path.join(root, file)
                relative_path = os.path.relpath(file_path, target_dir)
                file_hash = calculate_file_hash(file_path)
                if file_hash is not None:
                    target_files[relative_path] = file_hash

    # Copy new/changed files into the target mirror.
    for relative_path in source_files.keys():
        source_file_path = os.path.join(source_dir, relative_path)
        target_file_path = os.path.join(target_dir, relative_path)
        if (relative_path not in target_files
                or target_files[relative_path] != source_files[relative_path]):
            os.makedirs(os.path.dirname(target_file_path), exist_ok=True)
            shutil.copy2(source_file_path, target_file_path)
            files_to_upload.append(relative_path)

    # Drop files that no longer exist (or are no longer whitelisted) on the
    # source side.
    for relative_path in target_files.keys():
        if relative_path not in source_files:
            target_file_path = os.path.join(target_dir, relative_path)
            os.remove(target_file_path)

    return files_to_upload


def upload_files(api, repo_id, local_dir, remote_dir, files_to_upload, operations):
    """Append a CommitOperationAdd per file to the pending operations list."""
    for relative_path in files_to_upload:
        local_file_path = os.path.join(local_dir, relative_path)
        remote_file_path = os.path.join(remote_dir, relative_path)
        operations.append(CommitOperationAdd(
            path_in_repo=remote_file_path, path_or_fileobj=local_file_path))


def delete_files(api, repo_id, remote_dir, files_to_delete, operations):
    """Append a CommitOperationDelete per file to the pending operations list."""
    for relative_path in files_to_delete:
        remote_file_path = os.path.join(remote_dir, relative_path)
        operations.append(CommitOperationDelete(path_in_repo=remote_file_path))


def commit_operations(api, repo_id, operations, commit_message):
    """Push all pending add/delete operations as a single dataset commit.

    A no-op (with a log line) when there is nothing to commit; errors are
    logged rather than raised so the periodic sync loop keeps running.
    """
    try:
        if operations:
            api.create_commit(
                repo_id=repo_id,
                repo_type="dataset",
                operations=operations,
                commit_message=commit_message,
            )
            log_print("已成功更新云端版本!")
        else:
            log_print("当前版本已为最新,无需更新!")
    except Exception as e:
        log_print(f"更新提交失败: {str(e)}")


def download_files(api, repo_id, remote_dir, local_dir):
    """Download every repo file under remote_dir into local_dir.

    Each file is also mirrored under /tmp/<remote_file> so later sync runs
    can hash-compare against the downloaded baseline.
    """
    try:
        remote_files = api.list_repo_files(repo_id=repo_id, repo_type="dataset")
        filtered_files = [file for file in remote_files if file.startswith(remote_dir)]

        for remote_file in filtered_files:
            relative_path = os.path.relpath(remote_file, remote_dir)
            local_file_path = os.path.join(local_dir, relative_path)
            os.makedirs(os.path.dirname(local_file_path), exist_ok=True)
            tmp_filepath = api.hf_hub_download(
                repo_id=repo_id,
                filename=remote_file,
                repo_type="dataset",
            )
            if tmp_filepath and os.path.exists(tmp_filepath):
                shutil.copy2(tmp_filepath, local_file_path)
                # Fix: ensure nested directories exist before mirroring —
                # only /tmp/<APP_NAME> itself is pre-created by the shell.
                mirror_path = f"/tmp/{remote_file}"
                os.makedirs(os.path.dirname(mirror_path), exist_ok=True)
                shutil.copy2(tmp_filepath, mirror_path)
            log_print(f"已下载{remote_file} -> {local_file_path}")
    except Exception as e:
        log_print(f"下载失败: {str(e)}")


def sync_repository(api, repo_id, remote_dir, whitelist=None, blacklist=None):
    """One sync pass: upload changed local files, delete stale remote files."""
    log_print("开始数据同步进程...")

    source_dir = "./data"
    target_dir = f"/tmp/{remote_dir}"

    files_to_upload = compare_and_sync_directories(
        source_dir, target_dir, whitelist, blacklist)

    remote_files = api.list_repo_files(repo_id=repo_id, repo_type="dataset")
    local_files = []
    if os.path.exists(target_dir):
        for root, _, files in os.walk(target_dir):
            for file in files:
                file_path = os.path.join(root, file)
                relative_path = os.path.relpath(file_path, target_dir)
                local_files.append(relative_path)

    local_files_set = set(local_files)
    # NOTE(review): plain startswith() would also match a sibling dir that
    # shares the prefix (e.g. "AiChat2/") — acceptable for a single-app repo,
    # but confirm if multiple apps share this dataset.
    files_to_delete = [
        os.path.relpath(remote_file, remote_dir)
        for remote_file in remote_files
        if remote_file.startswith(remote_dir)
        and os.path.relpath(remote_file, remote_dir) not in local_files_set
    ]

    operations = []
    upload_files(api, repo_id, source_dir, remote_dir, files_to_upload, operations)
    delete_files(api, repo_id, remote_dir, files_to_delete, operations)
    commit_operations(api, repo_id, operations, f"Sync repository: {remote_dir}")

    if files_to_upload:
        log_print(f"文件已上传: {files_to_upload}")
    if files_to_delete:
        log_print(f"文件已删除: {files_to_delete}")


if __name__ == "__main__":
    # Usage: hf_sync.py <sync|download> <token> <repo_id> <remote_dir> [whitelist] [blacklist]
    if len(sys.argv) < 5:
        log_print("Usage: hf_sync.py <sync|download> <token> <repo_id> <remote_dir> "
                  "[whitelist] [blacklist]")
        sys.exit(2)
    action = sys.argv[1]
    token = sys.argv[2]
    repo_id = sys.argv[3]
    remote_dir = sys.argv[4]

    api = HfApi(token=token)
    source_dir = "./data"

    if action == "sync":
        # "" and "None" both mean "no filter"; otherwise split on commas.
        whitelist = (sys.argv[5].split(",")
                     if len(sys.argv) > 5 and sys.argv[5] not in ["", "None"] else None)
        blacklist = (sys.argv[6].split(",")
                     if len(sys.argv) > 6 and sys.argv[6] not in ["", "None"] else None)
        sync_repository(api, repo_id, remote_dir, whitelist, blacklist)
    elif action == "download":
        download_files(api, repo_id, remote_dir, source_dir)
EOL

# Periodically push local changes to the dataset repo.
sync_data() {
  SYNC_INTERVAL=${SYNC_INTERVAL:-7200}  # default sync interval: 7200 s (2 h)
  while true; do
    python3 /tmp/hf_sync.py sync "${HF_TOKEN}" "${DATASET_ID}" "${APP_NAME}" "${WHITELIST}" "${BLACKLIST}"
    sleep "${SYNC_INTERVAL}"
  done
}

# Restore the previous state first, then start the background upload loop.
# NOTE(review): the loop is backgrounded with no `wait`; this script is
# presumably sourced by (or followed by) the app's own startup command —
# if run standalone the shell exits immediately and the loop is orphaned.
python3 /tmp/hf_sync.py download "${HF_TOKEN}" "${DATASET_ID}" "${APP_NAME}"
sync_data &